// strat9_kernel/vfs/pty_scheme.rs

//! Pseudo-terminal (PTY) scheme.
//!
//! A PTY pair consists of a **master** (controlling process) and a **slave**
//! (child process). Data written to one side is readable from the other.
//!
//! # Scheme paths
//!
//! | Path           | Description                                                |
//! |----------------|------------------------------------------------------------|
//! | `/dev/pts/new` | Open to allocate a new PTY pair (returns the master fd).   |
//! | `/dev/pts/N`   | Open the slave side of PTY number N.                       |
//! | `/dev/pts/`    | Open the directory; reading it lists existing PTY numbers. |

use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
use core::fmt::Write;
use core::sync::atomic::{AtomicU64, Ordering};

use crate::{sync::SpinLock, syscall::error::SyscallError};

use super::scheme::{DirEntry, FileFlags, FileStat, OpenFlags, OpenResult, Scheme, DT_CHR};

/// Capacity, in bytes, of each direction of a PTY pair.
const PTY_BUF_SIZE: usize = 4096;

/// Fixed-capacity FIFO byte queue backing one direction of a PTY pair.
///
/// `head` is the next slot to read, `tail` the next slot to write and `count`
/// the number of bytes currently buffered; both cursors wrap modulo
/// [`PTY_BUF_SIZE`].
struct RingBuf {
    buf: [u8; PTY_BUF_SIZE],
    head: usize,
    tail: usize,
    count: usize,
}

impl RingBuf {
    /// Returns an empty ring buffer.
    const fn new() -> Self {
        Self {
            buf: [0; PTY_BUF_SIZE],
            head: 0,
            tail: 0,
            count: 0,
        }
    }

    /// Appends as much of `data` as fits and returns how many bytes were
    /// accepted. Bytes that do not fit are not stored (the caller sees a short
    /// count); already-buffered data is never overwritten.
    fn push(&mut self, data: &[u8]) -> usize {
        let accepted = data.len().min(PTY_BUF_SIZE - self.count);
        for &byte in &data[..accepted] {
            self.buf[self.tail] = byte;
            self.tail = (self.tail + 1) % PTY_BUF_SIZE;
        }
        self.count += accepted;
        accepted
    }

    /// Moves up to `out.len()` of the oldest buffered bytes into the front of
    /// `out` and returns how many were copied (`0` when empty). Slots of `out`
    /// past the returned count are left untouched.
    fn pop(&mut self, out: &mut [u8]) -> usize {
        let taken = out.len().min(self.count);
        for slot in &mut out[..taken] {
            *slot = self.buf[self.head];
            self.head = (self.head + 1) % PTY_BUF_SIZE;
        }
        self.count -= taken;
        taken
    }
}

70/// State of a single PTY pair.
71struct PtyPair {
72    /// Data written by master, read by slave.
73    to_slave: RingBuf,
74    /// Data written by slave, read by master.
75    to_master: RingBuf,
76    master_open: bool,
77    slave_open: bool,
78}
79
80/// Manages all PTY pairs.
81pub struct PtyScheme {
82    pairs: SpinLock<BTreeMap<u64, PtyPair>>,
83    next_pty: AtomicU64,
84    next_fid: AtomicU64,
85    /// Maps file_id → (pty_id, is_master).
86    handles: SpinLock<BTreeMap<u64, (u64, bool)>>,
87}
88
89impl PtyScheme {
90    /// Create a new PTY scheme instance.
91    pub fn new() -> Self {
92        Self {
93            pairs: SpinLock::new(BTreeMap::new()),
94            next_pty: AtomicU64::new(0),
95            next_fid: AtomicU64::new(1),
96            handles: SpinLock::new(BTreeMap::new()),
97        }
98    }
99
100    fn alloc_fid(&self) -> u64 {
101        self.next_fid.fetch_add(1, Ordering::Relaxed)
102    }
103}
104
105impl Scheme for PtyScheme {
106    fn open(&self, path: &str, _flags: OpenFlags) -> Result<OpenResult, SyscallError> {
107        let path = path.trim_start_matches('/');
108
109        if path == "new" {
110            let pty_id = self.next_pty.fetch_add(1, Ordering::Relaxed);
111            let pair = PtyPair {
112                to_slave: RingBuf::new(),
113                to_master: RingBuf::new(),
114                master_open: true,
115                slave_open: false,
116            };
117            self.pairs.lock().insert(pty_id, pair);
118
119            let fid = self.alloc_fid();
120            self.handles.lock().insert(fid, (pty_id, true));
121
122            Ok(OpenResult {
123                file_id: fid,
124                size: None,
125                flags: FileFlags::DEVICE,
126            })
127        } else if let Ok(pty_id) = path.parse::<u64>() {
128            let mut pairs = self.pairs.lock();
129            let pair = pairs.get_mut(&pty_id).ok_or(SyscallError::NotFound)?;
130            pair.slave_open = true;
131
132            let fid = self.alloc_fid();
133            self.handles.lock().insert(fid, (pty_id, false));
134
135            Ok(OpenResult {
136                file_id: fid,
137                size: None,
138                flags: FileFlags::DEVICE,
139            })
140        } else if path.is_empty() {
141            let fid = self.alloc_fid();
142            self.handles.lock().insert(fid, (u64::MAX, false));
143            Ok(OpenResult {
144                file_id: fid,
145                size: None,
146                flags: FileFlags::DIRECTORY,
147            })
148        } else {
149            Err(SyscallError::NotFound)
150        }
151    }
152
153    fn read(&self, file_id: u64, _offset: u64, buf: &mut [u8]) -> Result<usize, SyscallError> {
154        let (pty_id, is_master) = {
155            let handles = self.handles.lock();
156            *handles.get(&file_id).ok_or(SyscallError::BadHandle)?
157        };
158
159        if pty_id == u64::MAX {
160            return self.readdir_root(buf);
161        }
162
163        let mut pairs = self.pairs.lock();
164        let pair = pairs.get_mut(&pty_id).ok_or(SyscallError::BadHandle)?;
165
166        let ring = if is_master {
167            &mut pair.to_master
168        } else {
169            &mut pair.to_slave
170        };
171        let n = ring.pop(buf);
172        Ok(n)
173    }
174
175    fn write(&self, file_id: u64, _offset: u64, buf: &[u8]) -> Result<usize, SyscallError> {
176        let (pty_id, is_master) = {
177            let handles = self.handles.lock();
178            *handles.get(&file_id).ok_or(SyscallError::BadHandle)?
179        };
180
181        let mut pairs = self.pairs.lock();
182        let pair = pairs.get_mut(&pty_id).ok_or(SyscallError::BadHandle)?;
183
184        let ring = if is_master {
185            &mut pair.to_slave
186        } else {
187            &mut pair.to_master
188        };
189        let n = ring.push(buf);
190        Ok(n)
191    }
192
193    fn close(&self, file_id: u64) -> Result<(), SyscallError> {
194        let entry = self.handles.lock().remove(&file_id);
195        if let Some((pty_id, is_master)) = entry {
196            if pty_id != u64::MAX {
197                let mut pairs = self.pairs.lock();
198                if let Some(pair) = pairs.get_mut(&pty_id) {
199                    if is_master {
200                        pair.master_open = false;
201                    } else {
202                        pair.slave_open = false;
203                    }
204                    if !pair.master_open && !pair.slave_open {
205                        pairs.remove(&pty_id);
206                    }
207                }
208            }
209        }
210        Ok(())
211    }
212
213    fn stat(&self, file_id: u64) -> Result<FileStat, SyscallError> {
214        let handles = self.handles.lock();
215        let &(pty_id, _) = handles.get(&file_id).ok_or(SyscallError::BadHandle)?;
216        let mut st = FileStat::zeroed();
217        if pty_id == u64::MAX {
218            st.st_mode = 0o40755;
219        } else {
220            st.st_mode = 0o20666;
221            st.st_ino = pty_id;
222        }
223        st.st_nlink = 1;
224        Ok(st)
225    }
226
227    fn readdir(&self, _file_id: u64) -> Result<Vec<DirEntry>, SyscallError> {
228        let pairs = self.pairs.lock();
229        let mut entries = Vec::new();
230        for &pty_id in pairs.keys() {
231            entries.push(DirEntry {
232                ino: pty_id,
233                file_type: DT_CHR,
234                name: alloc::format!("{}", pty_id),
235            });
236        }
237        Ok(entries)
238    }
239}
240
241impl PtyScheme {
242    fn readdir_root(&self, buf: &mut [u8]) -> Result<usize, SyscallError> {
243        let pairs = self.pairs.lock();
244        let mut out = String::new();
245        for &pty_id in pairs.keys() {
246            out.push_str(&alloc::format!("{}\n", pty_id));
247        }
248        let bytes = out.as_bytes();
249        let n = bytes.len().min(buf.len());
250        buf[..n].copy_from_slice(&bytes[..n]);
251        Ok(n)
252    }
253}
254
255static PTY_SCHEME: SpinLock<Option<Arc<PtyScheme>>> = SpinLock::new(None);
256
257/// Initialize and mount the PTY scheme at `/dev/pts`.
258pub fn init_pty_scheme() {
259    let scheme = Arc::new(PtyScheme::new());
260    *PTY_SCHEME.lock() = Some(scheme.clone());
261    let _ = super::mount::mount("/dev/pts", scheme);
262}
263
264/// Get a reference to the global PTY scheme instance.
265pub fn get_pty_scheme() -> Option<Arc<PtyScheme>> {
266    PTY_SCHEME.lock().clone()
267}