// strat9_kernel/hardware/storage/virtio_block.rs
1//! VirtIO Block Device driver
2//!
3//! Provides disk I/O via VirtIO-blk protocol for QEMU/KVM environments.
4//! Implements the BlockDevice trait for integration with filesystem layers.
5//!
6//! Reference: VirtIO spec v1.2, Section 5.2 (Block Device)
7
8use crate::{
9    arch::x86_64::pci::{self, PciDevice},
10    hardware::virtio::{
11        common::{VirtioDevice, Virtqueue},
12        status,
13    },
14    memory,
15    sync::SpinLock,
16};
17use alloc::{boxed::Box, vec::Vec};
18use core::{mem, ptr};
19
/// Block device sector size in bytes.
///
/// VirtIO-blk requests address the disk in 512-byte units: the device
/// `capacity` and every request's `sector` field are expressed in these.
pub const SECTOR_SIZE: usize = 512;
/// VirtIO block device feature bits.
///
/// Bit positions for the device feature register read during negotiation.
/// Currently the driver accepts none of them (`guest_features = 0` during
/// initialization), so these serve as documentation of what the device may
/// offer.
pub mod features {
    /// Maximum segment size is available (`size_max` in config space)
    pub const VIRTIO_BLK_F_SIZE_MAX: u32 = 1 << 1;
    /// Maximum segment count is available (`seg_max` in config space)
    pub const VIRTIO_BLK_F_SEG_MAX: u32 = 1 << 2;
    /// Disk-style geometry fields are available in config space
    pub const VIRTIO_BLK_F_GEOMETRY: u32 = 1 << 4;
    /// Device is read-only
    pub const VIRTIO_BLK_F_RO: u32 = 1 << 5;
    /// Block size of the disk is available (`blk_size` in config space)
    pub const VIRTIO_BLK_F_BLK_SIZE: u32 = 1 << 6;
    /// Cache-flush command is supported
    pub const VIRTIO_BLK_F_FLUSH: u32 = 1 << 9;
    /// Topology information is available in config space
    pub const VIRTIO_BLK_F_TOPOLOGY: u32 = 1 << 10;
    /// Writeback cache mode is toggleable via config space
    pub const VIRTIO_BLK_F_CONFIG_WCE: u32 = 1 << 11;
    /// DISCARD command is supported
    pub const VIRTIO_BLK_F_DISCARD: u32 = 1 << 13;
    /// WRITE ZEROES command is supported
    pub const VIRTIO_BLK_F_WRITE_ZEROES: u32 = 1 << 14;
}
36
/// VirtIO block request types.
///
/// Values match the `type` field of the request header defined by the
/// VirtIO spec (v1.2, Section 5.2). Derives match the sibling
/// `BlockStatus` enum so request types can be logged and compared.
#[allow(dead_code)]
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestType {
    /// Read from device
    In = 0,
    /// Write to device
    Out = 1,
    /// Flush write cache
    Flush = 4,
    /// Get device ID
    GetId = 8,
    /// Discard sectors
    Discard = 11,
    /// Write zeroes
    WriteZeroes = 13,
}
54
/// VirtIO block request header.
///
/// First, device-readable descriptor of every request chain.
/// `#[repr(C)]` pins the 16-byte layout the device expects:
/// u32 type + u32 reserved + u64 sector.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct BlockRequestHeader {
    /// A `RequestType` value as its raw `u32` discriminant
    pub request_type: u32,
    /// Reserved/padding; always written as 0 by this driver
    pub reserved: u32,
    /// Starting sector (512-byte units) for the operation
    pub sector: u64,
}
63
/// VirtIO block request status.
///
/// Written by the device into the final, device-writable byte of each
/// request chain. The driver pre-fills that byte with 0xFF (not a valid
/// variant) so a device that never wrote it is detectable.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockStatus {
    /// Request completed successfully
    Ok = 0,
    /// Device reported an I/O error
    IoError = 1,
    /// Device does not support this request type
    Unsupported = 2,
}
72
/// Block device configuration space layout.
///
/// Mirrors the start of the device-specific config area. Note that the
/// driver currently reads `capacity` directly via `read_reg_u32` rather
/// than through this struct, which is kept as layout documentation.
#[repr(C)]
#[allow(dead_code)]
struct BlockConfig {
    /// Disk size in 512-byte sectors
    capacity: u64,
    /// Max segment size; meaningful only with VIRTIO_BLK_F_SIZE_MAX
    size_max: u32,
    /// Max segment count; meaningful only with VIRTIO_BLK_F_SEG_MAX
    seg_max: u32,
    /// Legacy geometry; meaningful only with VIRTIO_BLK_F_GEOMETRY
    geometry_cylinders: u16,
    geometry_heads: u8,
    geometry_sectors: u8,
    /// Preferred block size; meaningful only with VIRTIO_BLK_F_BLK_SIZE
    blk_size: u32,
    // ... other fields omitted for brevity
}
86
/// Block device trait (implemented by VirtIO-blk driver)
pub trait BlockDevice {
    /// Read from the device starting at `sector` into `buf`.
    ///
    /// `buf` must be at least `SECTOR_SIZE` bytes. The VirtIO
    /// implementation transfers `buf.len()` bytes, i.e. it may read more
    /// than one sector if the buffer is larger.
    fn read_sector(&self, sector: u64, buf: &mut [u8]) -> Result<(), BlockError>;
    /// Write one sector's worth of data from `buf` (at least `SECTOR_SIZE`
    /// bytes) to the device at `sector`.
    fn write_sector(&self, sector: u64, buf: &[u8]) -> Result<(), BlockError>;
    /// Get the total number of sectors
    fn sector_count(&self) -> u64;
}
96
/// Errors returned by block device operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum BlockError {
    /// Device reported a failure, submission failed, or the request timed out
    #[error("device I/O error")]
    IoError,
    /// Requested sector is at or beyond the device capacity
    #[error("invalid sector number")]
    InvalidSector,
    /// Caller's buffer is smaller than `SECTOR_SIZE`
    #[error("buffer too small")]
    BufferTooSmall,
    /// Frame/DMA allocation failed or the device is not usable
    #[error("device not ready")]
    NotReady,
}
108
/// VirtIO Block Device driver.
///
/// Owns the VirtIO transport, the single request virtqueue (queue 0), and
/// a capacity snapshot taken once during initialization.
pub struct VirtioBlockDevice {
    /// Underlying VirtIO PCI transport (register access, notifications)
    device: VirtioDevice,
    /// Request queue (queue index 0), locked around descriptor operations
    queue: SpinLock<Virtqueue>,
    /// Disk capacity in 512-byte sectors, read from config space in `new`
    capacity: u64,
}
115
// SAFETY: `queue` is guarded by a SpinLock and `capacity` is immutable after
// construction. NOTE(review): `device` is accessed from multiple contexts
// (e.g. `notify_queue` in `do_request`, `read_isr_status` in the IRQ path)
// WITHOUT holding the queue lock — this presumes VirtioDevice register
// accesses are safe to perform concurrently; confirm.
unsafe impl Send for VirtioBlockDevice {}
unsafe impl Sync for VirtioBlockDevice {}
119
impl VirtioBlockDevice {
    /// Initialize a VirtIO block device from a PCI device.
    ///
    /// Runs the VirtIO init handshake (reset → ACKNOWLEDGE → DRIVER →
    /// feature negotiation → FEATURES_OK → queue setup → DRIVER_OK) and
    /// then reads the disk capacity out of the device-specific config space.
    ///
    /// # Safety
    /// The PCI device must be a valid VirtIO block device
    pub unsafe fn new(pci_dev: PciDevice) -> Result<Self, &'static str> {
        log::info!("VirtIO-blk: Initializing device at {:?}", pci_dev.address);

        // Create VirtIO device (sets up transport register access)
        let device = VirtioDevice::new(pci_dev)?;

        // Reset device to a known state before negotiating
        device.reset();

        // Acknowledge device
        device.add_status(status::ACKNOWLEDGE as u8);

        // Indicate we know how to drive it
        device.add_status(status::DRIVER as u8);

        // Read and negotiate features
        let device_features = device.read_device_features();
        log::debug!("VirtIO-blk: Device features: 0x{:08x}", device_features);

        // We don't need any special features for basic operation yet,
        // so accept none of the offered bits.
        let guest_features = 0;
        device.write_guest_features(guest_features);

        // Features OK
        device.add_status(status::FEATURES_OK as u8);

        // Verify the device kept FEATURES_OK set, i.e. it accepted our
        // (empty) feature selection.
        if device.get_status() & (status::FEATURES_OK as u8) == 0 {
            return Err("Device doesn't support our feature set");
        }

        // Legacy PCI VirtIO exposes a fixed queue size in QUEUE_NUM.
        // The vring layout must match exactly what the device expects.
        let queue_size = device.queue_max_size(0);
        if queue_size == 0 {
            return Err("VirtIO-blk queue 0 is unavailable");
        }
        log::info!("VirtIO-blk: queue 0 size = {}", queue_size);

        // Create virtqueue (queue 0 is the request queue)
        let queue = Virtqueue::new(queue_size)?;

        // Setup queue with device
        device.setup_queue(0, &queue);

        // Driver ready
        device.add_status(status::DRIVER_OK as u8);

        // Read device capacity (offset 0 of the device-specific config).
        // For legacy devices, device-specific config starts right after the
        // 20-byte legacy VirtIO header, so the u64 `capacity` is read as two
        // u32 halves at offsets 20 and 24.
        // NOTE(review): on legacy transports with MSI-X enabled the
        // device-specific config shifts by 4 bytes — confirm MSI-X is not
        // enabled here.
        let capacity_low = device.read_reg_u32(20);
        let capacity_high = device.read_reg_u32(24);
        let capacity = ((capacity_high as u64) << 32) | (capacity_low as u64);

        log::info!(
            "VirtIO-blk: Capacity: {} sectors ({} MB)",
            capacity,
            (capacity * SECTOR_SIZE as u64) / (1024 * 1024)
        );

        log::info!("VirtIO-blk: Device initialized successfully");

        Ok(Self {
            device,
            queue: SpinLock::new(queue),
            capacity,
        })
    }

    /// Submit a block request and busy-wait for its completion.
    ///
    /// Builds a 2- or 3-descriptor chain — header / optional data / status —
    /// and submits it on queue 0. `data_buf` is `(buffer, is_write)`:
    /// `is_write == true` means data flows to the disk (device-readable
    /// descriptor), `false` means it flows from the disk (device-writable
    /// descriptor). Data is staged through a freshly allocated, physically
    /// contiguous bounce buffer in both directions.
    fn do_request(
        &self,
        request_type: RequestType,
        sector: u64,
        mut data_buf: Option<(&mut [u8], bool)>, // (buffer, is_write)
    ) -> Result<(), BlockError> {
        // Allocate one frame holding both the request header and the status
        // byte. Simple but wasteful; an optimized driver would use a slab
        // allocator or a pre-allocated request pool.
        let metadata_frame = crate::sync::with_irqs_disabled(|token| memory::allocate_frame(token))
            .map_err(|_| BlockError::NotReady)?;

        let metadata_phys = metadata_frame.start_address.as_u64();
        let metadata_virt = crate::memory::phys_to_virt(metadata_phys);
        let status_offset = mem::size_of::<BlockRequestHeader>() as u64;

        // Layout within the frame: [Header (16 bytes)][Status (1 byte)]
        let header_ptr = metadata_virt as *mut BlockRequestHeader;
        let status_ptr = (metadata_virt + status_offset) as *mut u8;

        // Setup request header (CPU access → virtual address)
        unsafe {
            ptr::write(
                header_ptr,
                BlockRequestHeader {
                    request_type: request_type as u32,
                    reserved: 0,
                    sector,
                },
            );
            // 0xFF is not a valid BlockStatus, so a never-written status
            // byte is distinguishable from a real completion.
            ptr::write(status_ptr, 0xFF); // Initialize with invalid status
        }

        // Build descriptor chain (DMA → physical addresses).
        // Tuple layout: (phys_addr, length, device_writable)
        let mut buffers = Vec::with_capacity(3);

        // 1. Header (Device Readable) : physical addr for DMA
        buffers.push((
            metadata_phys,
            mem::size_of::<BlockRequestHeader>() as u32,
            false,
        ));

        // 2. Data (Device Readable OR Writable), only if data_buf is provided
        let mut data_frame_info = None;

        if let Some((buf, is_write)) = data_buf.as_mut() {
            // We need a physically contiguous buffer for DMA, so stage
            // through a bounce buffer.
            // TODO: Support scatter-gather if the input buffer crosses page boundaries or isn't physical.

            let buf_size = buf.len();
            // Round up to whole pages, then to a power-of-two page count for
            // the contiguous allocator (order = log2 of page count).
            let buf_pages = (buf_size + 4095) / 4096;
            let buf_order = buf_pages.next_power_of_two().trailing_zeros() as u8;

            let buf_frame = crate::sync::with_irqs_disabled(|token| {
                memory::allocate_phys_contiguous(token, buf_order)
            })
            .map_err(|_| {
                // Bounce-buffer allocation failed: release the metadata
                // frame before bailing out.
                crate::sync::with_irqs_disabled(|token| {
                    memory::free_frame(token, metadata_frame);
                });
                BlockError::NotReady
            })?;

            let buf_phys = buf_frame.start_address.as_u64();
            let buf_virt = crate::memory::phys_to_virt(buf_phys);
            data_frame_info = Some((buf_frame, buf_order));

            // If WRITE (Out): Copy from source buf to DMA bounce buffer (CPU access → virtual)
            if *is_write {
                unsafe {
                    ptr::copy_nonoverlapping(buf.as_ptr(), buf_virt as *mut u8, buf_size);
                }
            }

            // `is_write` param tells us if we are writing TO disk.
            // Write to disk: device reads from memory (flags = 0)
            // Read from disk: device writes to memory (flags = WRITE)
            let device_writable = !*is_write;

            // DMA → physical address
            buffers.push((buf_phys, buf_size as u32, device_writable));
        }

        // 3. Status (Device Writable) : physical addr for DMA
        buffers.push((metadata_phys + status_offset, 1, true));

        // Submit request
        let mut queue = self.queue.lock();
        let token = match queue.add_buffer(&buffers) {
            Ok(t) => t,
            Err(_) => {
                drop(queue);
                // Submission failed: free both DMA allocations before
                // reporting the error.
                crate::sync::with_irqs_disabled(|token| {
                    memory::free_frame(token, metadata_frame);
                    if let Some((f, o)) = data_frame_info {
                        memory::free_phys_contiguous(token, f, o);
                    }
                });
                return Err(BlockError::IoError);
            }
        };

        // Notify device
        if queue.should_notify() {
            self.device.notify_queue(0);
        }
        drop(queue);

        // Wait for completion (busy-poll for now).
        //
        // IMPORTANT:
        // Do not use HLT here. This path can run from syscall context where IF
        // may be masked, and HLT would deadlock the CPU.
        // TODO: replace with proper waitqueue + interrupt completion.
        let mut spins = 0u32;
        loop {
            let mut queue = self.queue.lock();
            if queue.has_used() {
                // We don't check the token because we are single-threaded/blocking per device for now
                // But to be correct we should find OUR token.
                // virtio::common::Virtqueue::get_used currently pops the *next* used.
                // If there are multiple in flight, we might pop someone else's.
                // But here we are blocking, so only one in flight effectively.
                if let Some((used_token, _len)) = queue.get_used() {
                    if used_token == token {
                        break;
                    } else {
                        // This shouldn't happen in single-threaded blocking mode
                        // If it does, we just dropped someone else's completion.
                        log::warn!("VirtIO-blk: Received unexpected token {}", used_token);
                    }
                }
            }
            let (used_idx, last_used_idx) = queue.used_indices();
            drop(queue);
            spins = spins.saturating_add(1);
            if spins == 5_000_000 {
                // Timeout: dump diagnostics, free the DMA allocations, fail.
                // NOTE(review): the device may still DMA into these frames
                // after they are freed — confirm the allocator tolerates
                // this or that timeouts are effectively fatal.
                let isr = self.device.read_isr_status();
                log::error!(
                    "VirtIO-blk: request timeout sector={} token={} used_idx={} last_used_idx={} isr={}",
                    sector,
                    token,
                    used_idx,
                    last_used_idx,
                    isr
                );
                crate::serial_println!(
                    "[virtio-blk] timeout sector={} token={} used_idx={} last_used_idx={} isr={}",
                    sector,
                    token,
                    used_idx,
                    last_used_idx,
                    isr
                );
                // (The closure's `token` shadows the virtqueue token above:
                // this one is the IRQ-disable token.)
                crate::sync::with_irqs_disabled(|token| {
                    memory::free_frame(token, metadata_frame);
                    if let Some((f, o)) = data_frame_info {
                        memory::free_phys_contiguous(token, f, o);
                    }
                });
                return Err(BlockError::IoError);
            }
            core::hint::spin_loop();
        }

        // Check the status byte the device wrote
        let status_byte = unsafe { ptr::read(status_ptr) };

        // Post-processing
        if let Some((buf, is_write)) = data_buf {
            if let Some((buf_frame, buf_order)) = data_frame_info {
                let buf_virt = crate::memory::phys_to_virt(buf_frame.start_address.as_u64());

                // If Read (In): Copy from DMA buf to destination buf (CPU access → virtual).
                // Only on success — on error the bounce buffer contents are
                // not meaningful.
                if !is_write && status_byte == BlockStatus::Ok as u8 {
                    unsafe {
                        ptr::copy_nonoverlapping(
                            buf_virt as *const u8,
                            buf.as_mut_ptr(),
                            buf.len(),
                        );
                    }
                }

                // Free DMA buffer
                crate::sync::with_irqs_disabled(|token| {
                    memory::free_phys_contiguous(token, buf_frame, buf_order);
                });
            }
        }

        // Free metadata frame
        crate::sync::with_irqs_disabled(|token| {
            memory::free_frame(token, metadata_frame);
        });

        if status_byte == BlockStatus::Ok as u8 {
            Ok(())
        } else {
            log::error!("VirtIO-blk: Request failed with status {}", status_byte);
            Err(BlockError::IoError)
        }
    }
}
408
409impl BlockDevice for VirtioBlockDevice {
410    /// Reads sector.
411    fn read_sector(&self, sector: u64, buf: &mut [u8]) -> Result<(), BlockError> {
412        if sector >= self.capacity {
413            return Err(BlockError::InvalidSector);
414        }
415
416        if buf.len() < SECTOR_SIZE {
417            return Err(BlockError::BufferTooSmall);
418        }
419
420        self.do_request(RequestType::In, sector, Some((buf, false)))
421    }
422
423    /// Writes sector.
424    fn write_sector(&self, sector: u64, buf: &[u8]) -> Result<(), BlockError> {
425        if sector >= self.capacity {
426            return Err(BlockError::InvalidSector);
427        }
428
429        if buf.len() < SECTOR_SIZE {
430            return Err(BlockError::BufferTooSmall);
431        }
432
433        // Need mutable buffer for internal DMA operations signature (though we won't modify it if is_write=true)
434        // Our do_request takes &mut [u8], so we need to either change do_request or cast.
435        // It's safer to copy the input slice to a temp buffer if we needed to, but here do_request copies it to DMA anyway.
436        // But do_request signature expects &mut [u8] because it handles both read and write.
437        // We can just cast const to mut since we know we won't write to it if is_write=true.
438        // Or better, let's fix do_request to take Option<(&mut [u8], Direction)>.
439        // For now, let's do a safe copy to avoid unsafe hacks.
440
441        let mut buf_copy = [0u8; SECTOR_SIZE];
442        buf_copy[..SECTOR_SIZE].copy_from_slice(&buf[..SECTOR_SIZE]);
443
444        self.do_request(RequestType::Out, sector, Some((&mut buf_copy, true)))
445    }
446
447    /// Performs the sector count operation.
448    fn sector_count(&self) -> u64 {
449        self.capacity
450    }
451}
452
/// Global VirtIO block device
///
/// Populated at most once by `init()`; never cleared afterwards
/// (`get_device()` relies on this invariant).
static VIRTIO_BLOCK: SpinLock<Option<Box<VirtioBlockDevice>>> = SpinLock::new(None);

/// VirtIO block IRQ line (will be set during init)
///
/// NOTE(review): unsynchronized `static mut`; correctness relies on `init()`
/// completing before any `get_irq()` caller — consider an `AtomicU8`.
static mut VIRTIO_BLOCK_IRQ: u8 = 0;
458
459/// Initialize VirtIO block device
460///
461/// Scans PCI bus for VirtIO block devices and initializes the first one found.
462pub fn init() {
463    log::info!("VirtIO-blk: Scanning for devices...");
464
465    // Prefer strict class-based probe (mass storage), with fallback to
466    // vendor+device for odd firmware/virtual setups.
467    let pci_dev = match pci::probe_first(pci::ProbeCriteria {
468        vendor_id: Some(pci::vendor::VIRTIO),
469        device_id: Some(pci::device::VIRTIO_BLOCK),
470        class_code: Some(pci::class::MASS_STORAGE),
471        subclass: None,
472        prog_if: None,
473    })
474    .or_else(|| pci::find_virtio_device(pci::device::VIRTIO_BLOCK))
475    {
476        Some(dev) => dev,
477        None => {
478            log::warn!("VirtIO-blk: No block device found");
479            return;
480        }
481    };
482
483    // Read interrupt line from PCI config
484    let irq_line = pci_dev.read_config_u8(pci::config::INTERRUPT_LINE);
485
486    // Initialize device
487    match unsafe { VirtioBlockDevice::new(pci_dev) } {
488        Ok(device) => {
489            // Store IRQ line for interrupt handler
490            unsafe {
491                VIRTIO_BLOCK_IRQ = irq_line;
492            }
493
494            // Register device
495            *VIRTIO_BLOCK.lock() = Some(Box::new(device));
496
497            // Register IRQ handler in IDT
498            crate::arch::x86_64::idt::register_virtio_block_irq(irq_line);
499
500            log::info!("VirtIO-blk: Device initialized on IRQ {}", irq_line);
501        }
502        Err(e) => {
503            log::error!("VirtIO-blk: Failed to initialize device: {}", e);
504        }
505    }
506}
507
508/// Handle VirtIO block device interrupt
509///
510/// Called from the IDT IRQ handler when the VirtIO device signals completion.
511/// Acknowledges the interrupt and processes completed requests.
512pub fn handle_interrupt() {
513    // Acknowledge the interrupt at the device level
514    let lock = VIRTIO_BLOCK.lock();
515    if let Some(device) = lock.as_ref() {
516        // Read ISR status to check if interrupt is for us
517        let isr_status = device.device.read_isr_status();
518        if isr_status != 0 {
519            // Acknowledge the interrupt
520            device.device.ack_interrupt();
521
522            // Process completed requests (wake up waiting tasks)
523            // For now, just log the interrupt
524            log::trace!("VirtIO-blk: Interrupt handled (ISR={})", isr_status);
525        }
526    }
527}
528
529/// Get the global VirtIO block device
530pub fn get_device() -> Option<&'static VirtioBlockDevice> {
531    unsafe {
532        let lock = VIRTIO_BLOCK.lock();
533        if lock.is_some() {
534            // This is slightly unsafe if the lock is dropped and the box is moved,
535            // but the static Option is never cleared in this kernel.
536            // A safer way is needed for production.
537            let ptr = &**lock.as_ref().unwrap() as *const VirtioBlockDevice;
538            Some(&*ptr)
539        } else {
540            None
541        }
542    }
543}
544
/// Get the VirtIO block IRQ line
///
/// Returns 0 until `init()` has stored a line.
/// NOTE(review): unsynchronized read of a `static mut` — confirm all callers
/// run strictly after `init()` completes.
pub fn get_irq() -> u8 {
    unsafe { VIRTIO_BLOCK_IRQ }
}