Skip to main content

strat9_kernel/process/sched_classes/
fair.rs

1// SPDX-License-Identifier: MPL-2.0
2
3use super::{CurrentRuntime, SchedClassRq};
4use crate::process::task::Task;
5use alloc::{collections::BinaryHeap, sync::Arc};
6use core::cmp::{self, Reverse};
7
/// Weight assigned to nice level 0.
///
/// `nice_to_weight` scales this value for every other nice level, and
/// `update_current` uses it as the reference weight when converting elapsed
/// ticks into vruntime.
const WEIGHT_0: u64 = 1024;

/// Base time slice per task, in ticks, for the fair scheduler.
///
/// At TIMER_HZ=100 (10 ms/tick):
///   BASE_SLICE_TICKS = 1 -> 1 tick = 10 ms per task (matches `quantum_ms: 10`)
///
/// Previously this was mistakenly 10, giving 10 ticks = 100 ms slices and
/// effectively disabling preemption for lightly loaded workloads.
///
/// Derivation: target_ms = 10 ms, tick_ms = 1000 / TIMER_HZ = 10 ms -> 1 tick.
const BASE_SLICE_TICKS: u64 = 1;
20
21/// Performs the nice to weight operation.
22pub const fn nice_to_weight(nice: super::nice::Nice) -> u64 {
23    const FACTOR_NUMERATOR: u64 = 5;
24    const FACTOR_DENOMINATOR: u64 = 4;
25
26    const NICE_TO_WEIGHT: [u64; 40] = const {
27        let mut ret = [0; 40];
28        let mut index = 0;
29        let mut nice = super::nice::NiceValue::MIN.get();
30        while nice <= super::nice::NiceValue::MAX.get() {
31            ret[index] = match nice {
32                0 => WEIGHT_0,
33                nice @ 1.. => {
34                    let numerator = FACTOR_DENOMINATOR.pow(nice as u32);
35                    let denominator = FACTOR_NUMERATOR.pow(nice as u32);
36                    WEIGHT_0 * numerator / denominator
37                }
38                nice => {
39                    let numerator = FACTOR_NUMERATOR.pow((-nice) as u32);
40                    let denominator = FACTOR_DENOMINATOR.pow((-nice) as u32);
41                    WEIGHT_0 * numerator / denominator
42                }
43            };
44            index += 1;
45            nice += 1;
46        }
47        ret
48    };
49
50    NICE_TO_WEIGHT[(nice.value().get() + 20) as usize]
51}
52
53struct FairQueueItem(Arc<Task>, u64); // Task, vruntime
54
55impl FairQueueItem {
56    /// Performs the key operation.
57    fn key(&self) -> u64 {
58        self.1
59    }
60}
61
62impl core::fmt::Debug for FairQueueItem {
63    /// Performs the fmt operation.
64    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65        write!(f, "{}", self.key())
66    }
67}
68
69impl PartialEq for FairQueueItem {
70    /// Performs the eq operation.
71    fn eq(&self, other: &Self) -> bool {
72        self.key().eq(&other.key())
73    }
74}
75impl Eq for FairQueueItem {}
76
77impl PartialOrd for FairQueueItem {
78    /// Performs the partial cmp operation.
79    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
80        Some(self.cmp(other))
81    }
82}
83
84impl Ord for FairQueueItem {
85    /// Performs the cmp operation.
86    fn cmp(&self, other: &Self) -> cmp::Ordering {
87        self.key().cmp(&other.key())
88    }
89}
90
91pub struct FairClassRq {
92    entities: BinaryHeap<Reverse<FairQueueItem>>,
93    min_vruntime: u64,
94    total_weight: u64,
95}
96
97impl FairClassRq {
98    /// Creates a new instance.
99    pub fn new() -> Self {
100        Self {
101            entities: BinaryHeap::new(),
102            min_vruntime: 0,
103            total_weight: 0,
104        }
105    }
106
107    /// Performs the period operation.
108    fn period(&self) -> u64 {
109        // Total scheduling period (ticks) = BASE_SLICE_TICKS * nr_runnable.
110        // Ensures each runnable task gets at least BASE_SLICE_TICKS per round.
111        // Minimum = BASE_SLICE_TICKS to avoid division-by-zero in time_slice().
112        let count = (self.entities.len() + 1) as u64;
113        (BASE_SLICE_TICKS * count).max(BASE_SLICE_TICKS)
114    }
115
116    /// Performs the vtime slice operation.
117    fn vtime_slice(&self) -> u64 {
118        self.period() / (self.entities.len() + 1) as u64
119    }
120
121    /// Performs the time slice operation.
122    fn time_slice(&self, cur_weight: u64) -> u64 {
123        if self.total_weight + cur_weight == 0 {
124            return self.period();
125        }
126        self.period() * cur_weight / (self.total_weight + cur_weight)
127    }
128}
129
130impl SchedClassRq for FairClassRq {
131    /// Performs the enqueue operation.
132    fn enqueue(&mut self, task: Arc<Task>) {
133        if let super::SchedPolicy::Fair(nice) = task.sched_policy() {
134            let weight = nice_to_weight(nice);
135            let mut vruntime = task.vruntime();
136            // Start at min_vruntime if hasn't run yet or blocked heavily
137            if vruntime < self.min_vruntime {
138                vruntime = self.min_vruntime;
139            }
140            task.set_vruntime(vruntime);
141            self.total_weight += weight;
142            self.entities.push(Reverse(FairQueueItem(task, vruntime)));
143        }
144    }
145
146    /// Performs the len operation.
147    fn len(&self) -> usize {
148        self.entities.len()
149    }
150
151    /// Performs the pick next operation.
152    fn pick_next(&mut self) -> Option<Arc<Task>> {
153        let Reverse(FairQueueItem(task, _)) = self.entities.pop()?;
154        if let super::SchedPolicy::Fair(nice) = task.sched_policy() {
155            let weight = nice_to_weight(nice);
156            self.total_weight -= weight;
157        }
158        Some(task)
159    }
160
161    /// Updates current.
162    fn update_current(&mut self, rt: &CurrentRuntime, task: &Task, is_yield: bool) -> bool {
163        if is_yield {
164            return true;
165        }
166        if let super::SchedPolicy::Fair(nice) = task.sched_policy() {
167            let weight = nice_to_weight(nice);
168            let delta_vruntime = if weight == 0 {
169                0
170            } else {
171                rt.delta_ticks * WEIGHT_0 / weight
172            };
173            let vruntime = task.vruntime() + delta_vruntime;
174            task.set_vruntime(vruntime);
175
176            let leftmost = self.entities.peek();
177            self.min_vruntime = match leftmost {
178                Some(Reverse(leftmost)) => vruntime.min(leftmost.key()),
179                None => vruntime,
180            };
181
182            if leftmost.is_none() {
183                return false;
184            }
185
186            rt.period_delta_ticks > self.time_slice(weight)
187                || vruntime > self.min_vruntime + self.vtime_slice()
188        } else {
189            false
190        }
191    }
192
193    /// Performs the remove operation.
194    fn remove(&mut self, task_id: crate::process::TaskId) -> bool {
195        let mut vec = self.entities.drain().collect::<alloc::vec::Vec<_>>();
196        let old_len = vec.len();
197        let mut removed_weight = 0u64;
198        vec.retain(|Reverse(item)| {
199            if item.0.id == task_id {
200                if let super::SchedPolicy::Fair(nice) = item.0.sched_policy() {
201                    removed_weight += nice_to_weight(nice);
202                }
203                false
204            } else {
205                true
206            }
207        });
208        let removed = vec.len() < old_len;
209        if removed {
210            self.total_weight = self.total_weight.saturating_sub(removed_weight);
211        }
212        self.entities = alloc::collections::BinaryHeap::from(vec);
213        removed
214    }
215}