Skip to main content

strat9_kernel/hardware/virtio/
console.rs

1// VirtIO Console Driver
2// Reference: VirtIO spec v1.2, Section 5.3 (Console Device)
3
4use crate::{
5    arch::x86_64::pci::{self, Bar, ProbeCriteria},
6    memory::{allocate_dma_frame, phys_to_virt},
7};
8use alloc::{sync::Arc, vec::Vec};
9use core::sync::atomic::{AtomicBool, Ordering};
10use spin::{Mutex, Once};
11
/// Number of descriptors in every virtqueue this driver creates; also the length of the
/// fixed-size `ring` arrays declared in `VirtqAvail` and `VirtqUsed`.
const VIRTIO_RING_SIZE: usize = 8;
/// Maximum number of bytes submitted per transmit descriptor: `VirtioConsole::write`
/// splits larger writes into chunks of this size.
const VIRTIO_CONSOLE_PORT_SIZE: usize = 256;
14
/// Driver state for one VirtIO console PCI device.
///
/// The port list sits behind a spin `Mutex` so that `write`/`read` can be called
/// through a shared `&self` (the instance is stored in an `Arc`).
pub struct VirtioConsole {
    /// Handle to the device's register window.
    device: VirtioDevice,
    /// Console ports; `VirtioConsole::new` populates exactly one entry (port 0).
    ports: Mutex<Vec<VirtioConsolePort>>,
}
19
/// Raw handle to the device's memory-mapped register window.
struct VirtioDevice {
    /// Kernel virtual address of BAR 0, as produced by `phys_to_virt` in `VirtioConsole::new`.
    mmio: usize,
}
23
/// One console port together with its receive and transmit virtqueues.
struct VirtioConsolePort {
    /// Port number. Set at construction but never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    id: u16,
    /// Receive virtqueue; for port 0, `VirtioConsole::new` creates it on queue index 0.
    rx_queue: Virtqueue,
    /// Transmit virtqueue; for port 0, `VirtioConsole::new` creates it on queue index 1.
    tx_queue: Virtqueue,
    /// Whether the port is open. Set to `true` at construction but never read,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    open: bool,
}
32
/// A virtqueue of `VIRTIO_RING_SIZE` entries: descriptor table, available ring and
/// used ring, each in its own DMA frame, plus a single data buffer that every
/// submitted descriptor points at.
struct Virtqueue {
    /// Descriptor table (kernel virtual address).
    desc: *mut VirtqDesc,
    /// Available ring, written by the driver to hand descriptors to the device.
    avail: *mut VirtqAvail,
    /// Used ring, read by the driver to observe completions.
    used: *mut VirtqUsed,
    /// Physical addresses of the three ring frames. Never read after construction
    /// (presumably kept for teardown/debugging — TODO confirm), hence `dead_code`.
    #[allow(dead_code)]
    desc_phys: u64,
    #[allow(dead_code)]
    avail_phys: u64,
    #[allow(dead_code)]
    used_phys: u64,
    /// Physical address of the shared data buffer placed in each descriptor.
    buffer_phys: u64,
    /// Kernel virtual mapping of `buffer_phys`.
    buffer_virt: *mut u8,
    /// Descriptor indices that `write` may still hand out.
    free: Vec<u16>,
    /// Position in the used ring up to which the driver has consumed entries.
    last_used_idx: u16,
}

// SAFETY: the raw pointers refer to DMA frames that `Virtqueue::new` allocates for this
// queue alone, and every method that dereferences them takes `&mut self`, so moving a
// queue to another thread cannot introduce shared mutable aliasing.
unsafe impl Send for Virtqueue {}
50
/// One descriptor-table entry. `repr(C)` because the table is read by the device
/// through its physical address.
#[repr(C)]
#[derive(Clone, Copy)]
struct VirtqDesc {
    /// Physical address of the data buffer.
    addr: u64,
    /// Buffer length in bytes.
    len: u32,
    /// Descriptor flag bits as defined by the VirtIO spec (e.g. device-writable).
    flags: u16,
    /// Index of the next chained descriptor (this driver never chains; always 0).
    next: u16,
}

/// Available ring: the driver stores descriptor indices into `ring` (slot
/// `idx % VIRTIO_RING_SIZE`) and then advances `idx`. No trailing `used_event`
/// field is declared.
#[repr(C)]
struct VirtqAvail {
    flags: u16,
    idx: u16,
    ring: [u16; VIRTIO_RING_SIZE],
}

/// Used ring: the device advances `idx` as it completes buffers; the driver compares
/// it with its own `last_used_idx` to detect completions.
#[repr(C)]
struct VirtqUsed {
    flags: u16,
    idx: u16,
    ring: [VirtqUsedElem; VIRTIO_RING_SIZE],
}

/// One used-ring entry.
#[repr(C)]
#[derive(Clone, Copy)]
struct VirtqUsedElem {
    /// Descriptor index of the completed buffer (per the VirtIO spec).
    id: u32,
    /// Number of bytes the device wrote into the buffer.
    len: u32,
}
80
/// Feature mask for VIRTIO_F_VERSION_1 (feature bit 32).
const VIRTIO_F_VERSION_1: u64 = 1 << 32;
/// Feature bit *number* (not a mask) of the console multiport feature; callers build
/// the mask as `1 << VIRTIO_CONSOLE_F_MULTIPORT`.
const VIRTIO_CONSOLE_F_MULTIPORT: u32 = 1;

// Device status bits, OR-ed into the status register by `VirtioDevice::add_status`.
/// Status value meaning "reset". Not referenced in this file, hence `dead_code`.
#[allow(dead_code)]
const VIRTIO_STATUS_RESET: u8 = 0;
const VIRTIO_STATUS_ACKNOWLEDGE: u8 = 1;
const VIRTIO_STATUS_DRIVER: u8 = 2;
const VIRTIO_STATUS_DRIVER_OK: u8 = 4;
const VIRTIO_STATUS_FEATURES_OK: u8 = 8;
90
91impl VirtioConsole {
92    /// Creates a new instance.
93    pub unsafe fn new(pci_dev: pci::PciDevice) -> Result<Self, &'static str> {
94        let bar = match pci_dev.read_bar(0) {
95            Some(Bar::Memory64 { addr, .. }) => addr,
96            _ => return Err("Invalid BAR"),
97        };
98
99        let mmio = phys_to_virt(bar) as usize;
100        let mut device = VirtioDevice { mmio };
101
102        device.reset();
103        device.add_status(VIRTIO_STATUS_ACKNOWLEDGE);
104        device.add_status(VIRTIO_STATUS_DRIVER);
105
106        let features = device.read_features();
107        let mut guest_features = VIRTIO_F_VERSION_1;
108        if (features & (1 << VIRTIO_CONSOLE_F_MULTIPORT)) != 0 {
109            guest_features |= 1 << VIRTIO_CONSOLE_F_MULTIPORT;
110        }
111        device.write_features(guest_features);
112        device.add_status(VIRTIO_STATUS_FEATURES_OK);
113
114        if (device.read_status() & VIRTIO_STATUS_FEATURES_OK) == 0 {
115            return Err("Features negotiation failed");
116        }
117
118        let mut ports = Vec::new();
119        let rx_queue = Virtqueue::new(&mut device, 0)?;
120        let tx_queue = Virtqueue::new(&mut device, 1)?;
121
122        ports.push(VirtioConsolePort {
123            id: 0,
124            rx_queue,
125            tx_queue,
126            open: true,
127        });
128
129        device.add_status(VIRTIO_STATUS_DRIVER_OK);
130
131        Ok(Self {
132            device,
133            ports: Mutex::new(ports),
134        })
135    }
136
137    /// Performs the write operation.
138    pub fn write(&self, data: &[u8]) -> Result<usize, &'static str> {
139        let mut ports = self.ports.lock();
140        let port = ports.first_mut().ok_or("No console port")?;
141
142        for chunk in data.chunks(VIRTIO_CONSOLE_PORT_SIZE) {
143            port.tx_queue.write(chunk)?;
144            self.device.notify_queue(1);
145
146            loop {
147                if port.tx_queue.poll_used() {
148                    break;
149                }
150                core::hint::spin_loop();
151            }
152        }
153
154        Ok(data.len())
155    }
156
157    /// Performs the read operation.
158    pub fn read(&self, buf: &mut [u8]) -> Result<usize, &'static str> {
159        let mut ports = self.ports.lock();
160        let port = ports.first_mut().ok_or("No console port")?;
161        port.rx_queue.read(buf)
162    }
163}
164
165impl VirtioDevice {
166    /// Performs the reset operation.
167    fn reset(&mut self) {
168        unsafe {
169            (self.mmio as *mut u32).write_volatile(0);
170        }
171        core::hint::spin_loop();
172    }
173
174    /// Performs the add status operation.
175    fn add_status(&mut self, status: u8) {
176        unsafe {
177            let current = ((self.mmio + 0x14) as *const u8).read_volatile();
178            ((self.mmio + 0x14) as *mut u8).write_volatile(current | status);
179        }
180    }
181
182    /// Reads status.
183    fn read_status(&self) -> u8 {
184        unsafe { ((self.mmio + 0x14) as *const u8).read_volatile() }
185    }
186
187    /// Reads features.
188    fn read_features(&self) -> u64 {
189        unsafe {
190            let lo = (self.mmio as *const u32).read_volatile() as u64;
191            let hi = ((self.mmio + 4) as *const u32).read_volatile() as u64;
192            (hi << 32) | lo
193        }
194    }
195
196    /// Writes features.
197    fn write_features(&mut self, features: u64) {
198        unsafe {
199            (self.mmio as *mut u32).write_volatile((features & 0xFFFFFFFF) as u32);
200            ((self.mmio + 4) as *mut u32).write_volatile(((features >> 32) & 0xFFFFFFFF) as u32);
201        }
202    }
203
204    /// Performs the notify queue operation.
205    fn notify_queue(&self, queue: u16) {
206        unsafe {
207            let offset = ((self.mmio + 0x20) as *const u16).read_volatile() as usize;
208            let queue_notify = self.mmio + 0x50 + offset * 4;
209            (queue_notify as *mut u32).write_volatile(queue as u32);
210        }
211    }
212}
213
214impl Virtqueue {
215    /// Creates a new instance.
216    fn new(device: &mut VirtioDevice, queue_idx: u16) -> Result<Self, &'static str> {
217        unsafe {
218            ((device.mmio + 0x16) as *mut u16).write_volatile(queue_idx);
219            let max_size = ((device.mmio + 0x18) as *const u16).read_volatile();
220            if max_size < VIRTIO_RING_SIZE as u16 {
221                return Err("Queue size too small");
222            }
223            ((device.mmio + 0x16) as *mut u16).write_volatile(VIRTIO_RING_SIZE as u16);
224
225            let desc_frame = allocate_dma_frame().ok_or("Failed to allocate desc")?;
226            let avail_frame = allocate_dma_frame().ok_or("Failed to allocate avail")?;
227            let used_frame = allocate_dma_frame().ok_or("Failed to allocate used")?;
228
229            let desc_phys = desc_frame.start_address.as_u64();
230            let avail_phys = avail_frame.start_address.as_u64();
231            let used_phys = used_frame.start_address.as_u64();
232
233            let desc_virt = phys_to_virt(desc_phys) as *mut VirtqDesc;
234            let avail_virt = phys_to_virt(avail_phys) as *mut VirtqAvail;
235            let used_virt = phys_to_virt(used_phys) as *mut VirtqUsed;
236
237            core::ptr::write_bytes(
238                desc_virt,
239                0,
240                VIRTIO_RING_SIZE * core::mem::size_of::<VirtqDesc>(),
241            );
242            core::ptr::write_bytes(avail_virt, 0, core::mem::size_of::<VirtqAvail>());
243            core::ptr::write_bytes(used_virt, 0, core::mem::size_of::<VirtqUsed>());
244
245            ((device.mmio + 0x10) as *mut u32).write_volatile((desc_phys & 0xFFFFFFFF) as u32);
246            ((device.mmio + 0x1A) as *mut u16).write_volatile(0xFFFF);
247
248            let buffer_frame = allocate_dma_frame().ok_or("Failed to allocate buffer")?;
249            let buffer_phys = buffer_frame.start_address.as_u64();
250            let buffer_virt = phys_to_virt(buffer_phys) as *mut u8;
251            core::ptr::write_bytes(buffer_virt, 0, 4096);
252
253            let mut free = Vec::with_capacity(VIRTIO_RING_SIZE);
254            for i in 0..VIRTIO_RING_SIZE {
255                free.push(i as u16);
256            }
257
258            Ok(Self {
259                desc: desc_virt,
260                avail: avail_virt,
261                used: used_virt,
262                desc_phys,
263                avail_phys,
264                used_phys,
265                buffer_phys,
266                buffer_virt,
267                free,
268                last_used_idx: 0,
269            })
270        }
271    }
272
273    /// Performs the write operation.
274    fn write(&mut self, data: &[u8]) -> Result<(), &'static str> {
275        unsafe {
276            if self.free.is_empty() {
277                return Err("No free descriptors");
278            }
279            let desc_idx = self.free.pop().unwrap();
280
281            let desc = &mut *self.desc.add(desc_idx as usize);
282            core::ptr::copy_nonoverlapping(data.as_ptr(), self.buffer_virt, data.len());
283            desc.addr = self.buffer_phys;
284            desc.len = data.len() as u32;
285            desc.flags = 2;
286            desc.next = 0;
287
288            let avail = &mut *self.avail;
289            let idx = avail.idx as usize % VIRTIO_RING_SIZE;
290            avail.ring[idx] = desc_idx;
291            avail.idx = avail.idx.wrapping_add(1);
292        }
293        Ok(())
294    }
295
296    /// Performs the read operation.
297    fn read(&mut self, buf: &mut [u8]) -> Result<usize, &'static str> {
298        unsafe {
299            if self.last_used_idx == (*self.used).idx {
300                return Ok(0);
301            }
302
303            let idx = self.last_used_idx as usize % VIRTIO_RING_SIZE;
304            let elem = (*self.used).ring[idx];
305
306            let len = core::cmp::min(elem.len as usize, buf.len());
307            core::ptr::copy_nonoverlapping(self.buffer_virt, buf.as_mut_ptr(), len);
308
309            self.last_used_idx = self.last_used_idx.wrapping_add(1);
310            Ok(len)
311        }
312    }
313
314    /// Performs the poll used operation.
315    fn poll_used(&mut self) -> bool {
316        unsafe { self.last_used_idx != (*self.used).idx }
317    }
318}
319
/// The console registered by `init`, if any; `Once` allows it to be set at most once.
static CONSOLE_INSTANCE: Once<Arc<VirtioConsole>> = Once::new();
/// Set to `true` by `init` after a console is stored in `CONSOLE_INSTANCE`; read by `is_available`.
static CONSOLE_INITIALIZED: AtomicBool = AtomicBool::new(false);
322
323/// Performs the init operation.
324pub fn init() {
325    log::info!("[VirtIO-Console] Scanning for VirtIO Console devices...");
326
327    let candidates = pci::probe_all(ProbeCriteria {
328        vendor_id: Some(pci::vendor::VIRTIO),
329        device_id: Some(pci::device::VIRTIO_CONSOLE),
330        class_code: None,
331        subclass: None,
332        prog_if: None,
333    });
334
335    for pci_dev in candidates.into_iter() {
336        log::info!(
337            "VirtIO-Console: Found device at {:?} (VEN:{:04x} DEV:{:04x})",
338            pci_dev.address,
339            pci_dev.vendor_id,
340            pci_dev.device_id
341        );
342
343        pci_dev.enable_bus_master();
344
345        match unsafe { VirtioConsole::new(pci_dev) } {
346            Ok(console) => {
347                let arc = Arc::new(console);
348                CONSOLE_INSTANCE.call_once(|| arc.clone());
349                CONSOLE_INITIALIZED.store(true, Ordering::SeqCst);
350                log::info!("[VirtIO-Console] Initialized");
351                return;
352            }
353            Err(e) => {
354                log::warn!("VirtIO-Console: Failed to initialize device: {}", e);
355            }
356        }
357    }
358
359    log::info!("[VirtIO-Console] No device found");
360}
361
362/// Performs the write operation.
363pub fn write(data: &[u8]) -> Result<usize, &'static str> {
364    CONSOLE_INSTANCE
365        .get()
366        .ok_or("Console not initialized")?
367        .write(data)
368}
369
370/// Performs the read operation.
371pub fn read(buf: &mut [u8]) -> Result<usize, &'static str> {
372    CONSOLE_INSTANCE
373        .get()
374        .ok_or("Console not initialized")?
375        .read(buf)
376}
377
378/// Returns whether available.
379pub fn is_available() -> bool {
380    CONSOLE_INITIALIZED.load(Ordering::Relaxed)
381}