Skip to main content

strat9_kernel/ipc/
semaphore.rs

1use crate::sync::{SpinLock, WaitQueue};
2use alloc::{collections::BTreeMap, sync::Arc};
3use core::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
4
/// Opaque handle naming a semaphore in the global registry.
///
/// The wrapped integer is the raw identifier issued by `create_semaphore`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SemId(pub u64);

impl SemId {
    /// Unwraps the handle into its raw `u64` value.
    pub fn as_u64(self) -> u64 {
        let SemId(raw) = self;
        raw
    }

    /// Wraps a raw `u64` value as a handle; no validation is performed.
    pub fn from_u64(raw: u64) -> Self {
        SemId(raw)
    }
}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
20pub enum SemaphoreError {
21    #[error("would block")]
22    WouldBlock,
23    #[error("semaphore destroyed")]
24    Destroyed,
25    #[error("invalid initial value")]
26    InvalidValue,
27    #[error("semaphore not found")]
28    NotFound,
29}
30
31pub struct PosixSemaphore {
32    count: AtomicI32,
33    destroyed: AtomicBool,
34    waitq: WaitQueue,
35}
36
37impl PosixSemaphore {
38    /// Creates a new instance.
39    fn new(initial: u32) -> Self {
40        Self {
41            count: AtomicI32::new(initial as i32),
42            destroyed: AtomicBool::new(false),
43            waitq: WaitQueue::new(),
44        }
45    }
46
47    /// Performs the wait operation.
48    pub fn wait(&self) -> Result<(), SemaphoreError> {
49        self.waitq.wait_until(|| {
50            if self.destroyed.load(Ordering::Acquire) {
51                return Some(Err(SemaphoreError::Destroyed));
52            }
53            let cur = self.count.load(Ordering::Acquire);
54            if cur <= 0 {
55                return None;
56            }
57            match self.count.compare_exchange_weak(
58                cur,
59                cur - 1,
60                Ordering::AcqRel,
61                Ordering::Acquire,
62            ) {
63                Ok(_) => Some(Ok(())),
64                Err(_) => None,
65            }
66        })
67    }
68
69    /// Attempts to wait.
70    pub fn try_wait(&self) -> Result<(), SemaphoreError> {
71        if self.destroyed.load(Ordering::Acquire) {
72            return Err(SemaphoreError::Destroyed);
73        }
74        loop {
75            let cur = self.count.load(Ordering::Acquire);
76            if cur <= 0 {
77                return Err(SemaphoreError::WouldBlock);
78            }
79            if self
80                .count
81                .compare_exchange_weak(cur, cur - 1, Ordering::AcqRel, Ordering::Acquire)
82                .is_ok()
83            {
84                return Ok(());
85            }
86        }
87    }
88
89    /// Performs the post operation.
90    pub fn post(&self) -> Result<(), SemaphoreError> {
91        if self.destroyed.load(Ordering::Acquire) {
92            return Err(SemaphoreError::Destroyed);
93        }
94        loop {
95            let cur = self.count.load(Ordering::Acquire);
96            if cur >= i32::MAX {
97                return Err(SemaphoreError::InvalidValue);
98            }
99            if self
100                .count
101                .compare_exchange_weak(cur, cur + 1, Ordering::AcqRel, Ordering::Acquire)
102                .is_ok()
103            {
104                break;
105            }
106        }
107        self.waitq.wake_one();
108        Ok(())
109    }
110
111    /// Performs the destroy operation.
112    pub fn destroy(&self) {
113        self.destroyed.store(true, Ordering::Release);
114        self.waitq.wake_all();
115    }
116
117    /// Performs the count operation.
118    pub fn count(&self) -> i32 {
119        self.count.load(Ordering::Acquire)
120    }
121
122    /// Returns whether destroyed.
123    pub fn is_destroyed(&self) -> bool {
124        self.destroyed.load(Ordering::Acquire)
125    }
126}
127
128static NEXT_SEM_ID: AtomicU64 = AtomicU64::new(1);
129static SEMAPHORES: SpinLock<Option<BTreeMap<SemId, Arc<PosixSemaphore>>>> = SpinLock::new(None);
130
131/// Performs the ensure registry operation.
132fn ensure_registry(guard: &mut Option<BTreeMap<SemId, Arc<PosixSemaphore>>>) {
133    if guard.is_none() {
134        *guard = Some(BTreeMap::new());
135    }
136}
137
138/// Creates semaphore.
139pub fn create_semaphore(initial: u32) -> Result<SemId, SemaphoreError> {
140    if initial > i32::MAX as u32 {
141        return Err(SemaphoreError::InvalidValue);
142    }
143    let id = SemId(NEXT_SEM_ID.fetch_add(1, Ordering::Relaxed));
144    let sem = Arc::new(PosixSemaphore::new(initial));
145    let mut reg = SEMAPHORES.lock();
146    ensure_registry(&mut *reg);
147    reg.as_mut().unwrap().insert(id, sem);
148    Ok(id)
149}
150
151/// Returns semaphore.
152pub fn get_semaphore(id: SemId) -> Option<Arc<PosixSemaphore>> {
153    let reg = SEMAPHORES.lock();
154    reg.as_ref().and_then(|m| m.get(&id).cloned())
155}
156
157/// Destroys semaphore.
158pub fn destroy_semaphore(id: SemId) -> Result<(), SemaphoreError> {
159    let sem = {
160        let mut reg = SEMAPHORES.lock();
161        let map = reg.as_mut().ok_or(SemaphoreError::NotFound)?;
162        map.remove(&id).ok_or(SemaphoreError::NotFound)?
163    };
164    sem.destroy();
165    Ok(())
166}