strat9_kernel/ipc/
semaphore.rs1use crate::sync::{SpinLock, WaitQueue};
2use alloc::{collections::BTreeMap, sync::Arc};
3use core::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
4
/// Numeric handle naming a semaphore; it is the key type of the global semaphore registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SemId(pub u64);

impl SemId {
    /// Returns the raw numeric value carried by this handle.
    pub fn as_u64(self) -> u64 {
        let SemId(raw) = self;
        raw
    }

    /// Wraps a raw numeric value as a handle.
    ///
    /// No validation is performed: the result may not name any registered semaphore.
    pub fn from_u64(raw: u64) -> Self {
        SemId(raw)
    }
}
18
/// Errors reported by `PosixSemaphore` operations and by the semaphore registry functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum SemaphoreError {
    /// `try_wait` found no unit available (the count was zero or below).
    #[error("would block")]
    WouldBlock,
    /// The semaphore's destroyed flag was set when the operation checked it.
    #[error("semaphore destroyed")]
    Destroyed,
    /// Returned by `create_semaphore` when the initial value exceeds `i32::MAX`.
    /// NOTE: `post` also returns this when the count is already `i32::MAX`,
    /// even though the message text mentions only the initial value.
    #[error("invalid initial value")]
    InvalidValue,
    /// No semaphore is registered under the requested id (or the registry was never created).
    #[error("semaphore not found")]
    NotFound,
}
30
31pub struct PosixSemaphore {
32 count: AtomicI32,
33 destroyed: AtomicBool,
34 waitq: WaitQueue,
35}
36
37impl PosixSemaphore {
38 fn new(initial: u32) -> Self {
40 Self {
41 count: AtomicI32::new(initial as i32),
42 destroyed: AtomicBool::new(false),
43 waitq: WaitQueue::new(),
44 }
45 }
46
47 pub fn wait(&self) -> Result<(), SemaphoreError> {
49 self.waitq.wait_until(|| {
50 if self.destroyed.load(Ordering::Acquire) {
51 return Some(Err(SemaphoreError::Destroyed));
52 }
53 let cur = self.count.load(Ordering::Acquire);
54 if cur <= 0 {
55 return None;
56 }
57 match self.count.compare_exchange_weak(
58 cur,
59 cur - 1,
60 Ordering::AcqRel,
61 Ordering::Acquire,
62 ) {
63 Ok(_) => Some(Ok(())),
64 Err(_) => None,
65 }
66 })
67 }
68
69 pub fn try_wait(&self) -> Result<(), SemaphoreError> {
71 if self.destroyed.load(Ordering::Acquire) {
72 return Err(SemaphoreError::Destroyed);
73 }
74 loop {
75 let cur = self.count.load(Ordering::Acquire);
76 if cur <= 0 {
77 return Err(SemaphoreError::WouldBlock);
78 }
79 if self
80 .count
81 .compare_exchange_weak(cur, cur - 1, Ordering::AcqRel, Ordering::Acquire)
82 .is_ok()
83 {
84 return Ok(());
85 }
86 }
87 }
88
89 pub fn post(&self) -> Result<(), SemaphoreError> {
91 if self.destroyed.load(Ordering::Acquire) {
92 return Err(SemaphoreError::Destroyed);
93 }
94 loop {
95 let cur = self.count.load(Ordering::Acquire);
96 if cur >= i32::MAX {
97 return Err(SemaphoreError::InvalidValue);
98 }
99 if self
100 .count
101 .compare_exchange_weak(cur, cur + 1, Ordering::AcqRel, Ordering::Acquire)
102 .is_ok()
103 {
104 break;
105 }
106 }
107 self.waitq.wake_one();
108 Ok(())
109 }
110
111 pub fn destroy(&self) {
113 self.destroyed.store(true, Ordering::Release);
114 self.waitq.wake_all();
115 }
116
117 pub fn count(&self) -> i32 {
119 self.count.load(Ordering::Acquire)
120 }
121
122 pub fn is_destroyed(&self) -> bool {
124 self.destroyed.load(Ordering::Acquire)
125 }
126}
127
/// Source of raw ids for `create_semaphore`. Starts at 1, so id 0 is not issued
/// unless this u64 counter ever wraps.
static NEXT_SEM_ID: AtomicU64 = AtomicU64::new(1);
/// Global registry of semaphores keyed by handle. It is `None` until the first
/// `create_semaphore` call initialises it through `ensure_registry`.
static SEMAPHORES: SpinLock<Option<BTreeMap<SemId, Arc<PosixSemaphore>>>> = SpinLock::new(None);
130
131fn ensure_registry(guard: &mut Option<BTreeMap<SemId, Arc<PosixSemaphore>>>) {
133 if guard.is_none() {
134 *guard = Some(BTreeMap::new());
135 }
136}
137
138pub fn create_semaphore(initial: u32) -> Result<SemId, SemaphoreError> {
140 if initial > i32::MAX as u32 {
141 return Err(SemaphoreError::InvalidValue);
142 }
143 let id = SemId(NEXT_SEM_ID.fetch_add(1, Ordering::Relaxed));
144 let sem = Arc::new(PosixSemaphore::new(initial));
145 let mut reg = SEMAPHORES.lock();
146 ensure_registry(&mut *reg);
147 reg.as_mut().unwrap().insert(id, sem);
148 Ok(id)
149}
150
151pub fn get_semaphore(id: SemId) -> Option<Arc<PosixSemaphore>> {
153 let reg = SEMAPHORES.lock();
154 reg.as_ref().and_then(|m| m.get(&id).cloned())
155}
156
157pub fn destroy_semaphore(id: SemId) -> Result<(), SemaphoreError> {
159 let sem = {
160 let mut reg = SEMAPHORES.lock();
161 let map = reg.as_mut().ok_or(SemaphoreError::NotFound)?;
162 map.remove(&id).ok_or(SemaphoreError::NotFound)?
163 };
164 sem.destroy();
165 Ok(())
166}