Skip to main content

strat9_kernel/ostd/
util.rs

//! Utility types and functions for OSTD.
//!
//! Provides common utility types, including:
//! - ID sets for tracking CPU and task IDs
//! - Bit-manipulation and alignment helpers
//! - The shared `Error` type and `Result` alias
7
8#![deny(unsafe_code)]
9
10extern crate alloc;
11
12use alloc::vec::Vec;
13
/// A set of IDs (e.g., CPU IDs, task IDs).
///
/// IDs in `0..64` are stored in a single `u64` bitmap, so small sets need no
/// heap allocation. IDs of 64 and above fall back to a `Vec` that is kept
/// sorted and free of duplicates: membership tests on it are `O(log n)`, and
/// iteration yields every ID in ascending order.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct IdSet {
    /// Bitmap for IDs `0..64`: bit `i` is set iff ID `i` is in the set.
    low_bits: u64,
    /// IDs `>= 64`, sorted ascending with no duplicates. Every mutating method
    /// preserves this invariant; the derived `PartialEq`/`Hash` rely on it to
    /// compare sets by content rather than by insertion order.
    high_ids: Vec<usize>,
}

impl IdSet {
    /// Number of IDs that fit in the inline bitmap (`0..LOW_CAPACITY`).
    const LOW_CAPACITY: usize = u64::BITS as usize;

    /// Creates a new, empty ID set without allocating.
    pub const fn new() -> Self {
        Self {
            low_bits: 0,
            high_ids: Vec::new(),
        }
    }

    /// Creates an ID set containing every ID in `0..=max` (inclusive).
    pub fn all(max: usize) -> Self {
        // Fill both halves directly: inserting one ID at a time would search
        // `high_ids` on every insert.
        let low_bits = if max >= Self::LOW_CAPACITY - 1 {
            u64::MAX
        } else {
            // Keep only the low `max + 1` bits.
            u64::MAX >> (Self::LOW_CAPACITY - 1 - max)
        };
        let high_ids = if max >= Self::LOW_CAPACITY {
            // Already ascending and unique, so the sorted invariant holds.
            (Self::LOW_CAPACITY..=max).collect()
        } else {
            Vec::new()
        };
        Self { low_bits, high_ids }
    }

    /// Returns the bitmap mask for `id`, or `None` if `id` belongs in
    /// `high_ids`.
    #[inline]
    fn low_mask(id: usize) -> Option<u64> {
        if id < Self::LOW_CAPACITY {
            Some(1u64 << id)
        } else {
            None
        }
    }

    /// Inserts `id` into the set; inserting an existing ID is a no-op.
    pub fn insert(&mut self, id: usize) {
        match Self::low_mask(id) {
            Some(mask) => self.low_bits |= mask,
            None => {
                // `Err(pos)` is the slot that keeps `high_ids` sorted.
                if let Err(pos) = self.high_ids.binary_search(&id) {
                    self.high_ids.insert(pos, id);
                }
            }
        }
    }

    /// Removes `id` from the set; removing an absent ID is a no-op.
    pub fn remove(&mut self, id: usize) {
        match Self::low_mask(id) {
            Some(mask) => self.low_bits &= !mask,
            None => {
                if let Ok(pos) = self.high_ids.binary_search(&id) {
                    // `Vec::remove` shifts the tail, preserving sort order.
                    self.high_ids.remove(pos);
                }
            }
        }
    }

    /// Returns `true` if `id` is in the set.
    pub fn contains(&self, id: usize) -> bool {
        match Self::low_mask(id) {
            Some(mask) => (self.low_bits & mask) != 0,
            None => self.high_ids.binary_search(&id).is_ok(),
        }
    }

    /// Returns `true` if the set contains no IDs.
    pub fn is_empty(&self) -> bool {
        self.low_bits == 0 && self.high_ids.is_empty()
    }

    /// Returns the number of IDs in the set.
    pub fn len(&self) -> usize {
        self.low_bits.count_ones() as usize + self.high_ids.len()
    }

    /// Removes every ID from the set.
    ///
    /// The `Vec` capacity of the high IDs is retained.
    pub fn clear(&mut self) {
        self.low_bits = 0;
        self.high_ids.clear();
    }

    /// Returns an iterator over the IDs in the set, in ascending order.
    pub fn iter(&self) -> IdSetIter<'_> {
        IdSetIter {
            low_bits: self.low_bits,
            high_iter: self.high_ids.iter(),
        }
    }
}

impl<'a> IntoIterator for &'a IdSet {
    type Item = usize;
    type IntoIter = IdSetIter<'a>;

    /// Iterates the set's IDs in ascending order; same as [`IdSet::iter`].
    fn into_iter(self) -> IdSetIter<'a> {
        self.iter()
    }
}

impl Extend<usize> for IdSet {
    /// Inserts every ID produced by `iter`.
    fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
        for id in iter {
            self.insert(id);
        }
    }
}

impl core::iter::FromIterator<usize> for IdSet {
    /// Builds a set from the IDs produced by `iter`, ignoring duplicates.
    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
        let mut set = Self::new();
        set.extend(iter);
        set
    }
}

/// Iterator over the IDs in an [`IdSet`], yielded in ascending order.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct IdSetIter<'a> {
    /// Bitmap IDs not yet yielded; each step clears the lowest set bit.
    low_bits: u64,
    /// Remaining high IDs (already sorted by `IdSet`).
    high_iter: core::slice::Iter<'a, usize>,
}

impl Iterator for IdSetIter<'_> {
    type Item = usize;

    /// Yields the next ID: all bitmap IDs first, then the high IDs.
    fn next(&mut self) -> Option<Self::Item> {
        if self.low_bits != 0 {
            let id = self.low_bits.trailing_zeros() as usize;
            // Clear the lowest set bit so the next call finds the following ID.
            self.low_bits &= self.low_bits - 1;
            return Some(id);
        }
        self.high_iter.next().copied()
    }

    /// Exact count of the IDs still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.low_bits.count_ones() as usize + self.high_iter.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for IdSetIter<'_> {}

// Once `low_bits` reaches zero it stays zero, and slice iterators are fused.
impl core::iter::FusedIterator for IdSetIter<'_> {}

/// A CPU set for tracking which CPUs are online/active
pub type CpuSet = IdSet;
135
/// Bit manipulation utilities
pub mod bits {
    /// Rounds `value` up to the smallest multiple of `align` that is `>= value`.
    ///
    /// `align` must be a non-zero power of two.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `align` is not a power of two or if
    /// `value + align - 1` overflows; without overflow checks the sum wraps.
    /// Use [`checked_align_up`] when `value` may be close to `usize::MAX`.
    #[inline]
    pub const fn align_up(value: usize, align: usize) -> usize {
        debug_assert!(align.is_power_of_two());
        (value + align - 1) & !(align - 1)
    }

    /// Like [`align_up`], but returns `None` instead of overflowing.
    ///
    /// `align` must be a non-zero power of two (checked in debug builds).
    #[inline]
    pub const fn checked_align_up(value: usize, align: usize) -> Option<usize> {
        debug_assert!(align.is_power_of_two());
        // The add is the only step that can overflow; masking cannot.
        match value.checked_add(align - 1) {
            Some(bumped) => Some(bumped & !(align - 1)),
            None => None,
        }
    }

    /// Rounds `value` down to the largest multiple of `align` that is `<= value`.
    ///
    /// `align` must be a non-zero power of two (checked in debug builds).
    #[inline]
    pub const fn align_down(value: usize, align: usize) -> usize {
        debug_assert!(align.is_power_of_two());
        value & !(align - 1)
    }

    /// Returns `true` if `value` is a multiple of `align`.
    ///
    /// `align` must be a non-zero power of two (checked in debug builds).
    #[inline]
    pub const fn is_aligned(value: usize, align: usize) -> bool {
        debug_assert!(align.is_power_of_two());
        value & (align - 1) == 0
    }

    /// Returns the number of leading zeros in a u64
    #[inline]
    pub const fn leading_zeros(x: u64) -> u32 {
        x.leading_zeros()
    }

    /// Returns the number of trailing zeros in a u64
    #[inline]
    pub const fn trailing_zeros(x: u64) -> u32 {
        x.trailing_zeros()
    }

    /// Returns the number of set bits in a u64
    #[inline]
    pub const fn count_ones(x: u64) -> u32 {
        x.count_ones()
    }

    /// Returns the smallest power of two `>= x`; returns `1` for `x == 0`.
    ///
    /// Delegates to [`usize::next_power_of_two`], which works for any pointer
    /// width. If the result would exceed `usize::MAX` it panics in debug
    /// builds and returns `0` otherwise, matching the previous hand-rolled
    /// version.
    #[inline]
    pub const fn next_power_of_two(x: usize) -> usize {
        x.next_power_of_two()
    }

    /// Returns `floor(log2(x))`; returns `0` for `x == 0`.
    #[inline]
    pub const fn log2_floor(x: usize) -> u32 {
        if x == 0 {
            return 0;
        }
        // Index of the highest set bit. Derive it from `usize::BITS` rather
        // than a hard-coded 31, which underflowed on 64-bit targets.
        usize::BITS - 1 - x.leading_zeros()
    }

    /// Returns `ceil(log2(x))`; returns `0` for `x == 0`.
    #[inline]
    pub const fn log2_ceil(x: usize) -> u32 {
        if x == 0 {
            return 0;
        }
        let floor = log2_floor(x);
        // Only exact powers of two have ceil == floor.
        if x.is_power_of_two() {
            floor
        } else {
            floor + 1
        }
    }
}
216
217/// Round up a value to the nearest multiple of align
218#[inline]
219pub const fn round_up(value: usize, align: usize) -> usize {
220    bits::align_up(value, align)
221}
222
223/// Round down a value to the nearest multiple of align
224#[inline]
225pub const fn round_down(value: usize, align: usize) -> usize {
226    bits::align_down(value, align)
227}
228
229#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
230pub enum Error {
231    #[error("out of memory")]
232    OutOfMemory,
233    #[error("invalid argument")]
234    InvalidArgument,
235    #[error("not found")]
236    NotFound,
237    #[error("already exists")]
238    AlreadyExists,
239    #[error("permission denied")]
240    PermissionDenied,
241    #[error("busy")]
242    Busy,
243    #[error("page fault")]
244    PageFault,
245    #[error("architecture error: {0}")]
246    ArchError(&'static str),
247}
248
249/// Result type alias for OSTD operations
250pub type Result<T> = core::result::Result<T, Error>;