// strat9_kernel/hardware/storage/virtio_block.rs
1//! VirtIO Block Device driver
2//!
3//! Provides disk I/O via VirtIO-blk protocol for QEMU/KVM environments.
4//! Implements the BlockDevice trait for integration with filesystem layers.
5//!
6//! Reference: VirtIO spec v1.2, Section 5.2 (Block Device)
7
8use crate::{
9    arch::x86_64::pci::{self, PciDevice},
10    hardware::virtio::{
11        common::{VirtioDevice, Virtqueue},
12        status,
13    },
14    memory,
15    sync::SpinLock,
16};
17use alloc::{boxed::Box, vec::Vec};
18use core::{mem, ptr};
19
20/// Block device sector size
21pub const SECTOR_SIZE: usize = 512;
22
23/// VirtIO block device features
/// VirtIO block device feature bits (VirtIO spec v1.2, §5.2.3).
///
/// These are offered by the device in its feature register; the driver
/// currently negotiates none of them (see `VirtioBlockDevice::new`).
pub mod features {
    /// Maximum size of any single segment is in `size_max`.
    pub const VIRTIO_BLK_F_SIZE_MAX: u32 = 1 << 1;
    /// Maximum number of segments per request is in `seg_max`.
    pub const VIRTIO_BLK_F_SEG_MAX: u32 = 1 << 2;
    /// Disk-style geometry (cylinders/heads/sectors) available in config.
    pub const VIRTIO_BLK_F_GEOMETRY: u32 = 1 << 4;
    /// Device is read-only.
    pub const VIRTIO_BLK_F_RO: u32 = 1 << 5;
    /// Block size of the disk is in `blk_size`.
    pub const VIRTIO_BLK_F_BLK_SIZE: u32 = 1 << 6;
    /// Cache-flush command (`RequestType::Flush`) is supported.
    pub const VIRTIO_BLK_F_FLUSH: u32 = 1 << 9;
    /// Device exports optimal I/O alignment/topology info.
    pub const VIRTIO_BLK_F_TOPOLOGY: u32 = 1 << 10;
    /// Writeback/writethrough cache mode is toggleable via config.
    pub const VIRTIO_BLK_F_CONFIG_WCE: u32 = 1 << 11;
    /// Discard (TRIM) command is supported.
    pub const VIRTIO_BLK_F_DISCARD: u32 = 1 << 13;
    /// Write-zeroes command is supported.
    pub const VIRTIO_BLK_F_WRITE_ZEROES: u32 = 1 << 14;
}
36
37/// VirtIO block request types
/// VirtIO block request types.
///
/// Discriminant values are the on-the-wire `type` field of the request
/// header and must match the VirtIO spec exactly — do not renumber.
#[allow(dead_code)]
#[repr(u32)]
pub enum RequestType {
    /// Read from device
    In = 0,
    /// Write to device
    Out = 1,
    /// Flush write cache
    Flush = 4,
    /// Get device ID
    GetId = 8,
    /// Discard sectors
    Discard = 11,
    /// Write zeroes
    WriteZeroes = 13,
}
54
55/// VirtIO block request header
56#[repr(C)]
57#[derive(Debug, Clone, Copy)]
58pub struct BlockRequestHeader {
59    pub request_type: u32,
60    pub reserved: u32,
61    pub sector: u64,
62}
63
64/// VirtIO block request status
65#[repr(u8)]
66#[derive(Debug, Clone, Copy, PartialEq, Eq)]
67pub enum BlockStatus {
68    Ok = 0,
69    IoError = 1,
70    Unsupported = 2,
71}
72
73/// Block device configuration space
74#[repr(C)]
75#[allow(dead_code)]
76struct BlockConfig {
77    capacity: u64,
78    size_max: u32,
79    seg_max: u32,
80    geometry_cylinders: u16,
81    geometry_heads: u8,
82    geometry_sectors: u8,
83    blk_size: u32,
84    // ... other fields omitted for brevity
85}
86
87/// Block device trait (implemented by VirtIO-blk driver)
88pub trait BlockDevice {
89    /// Read sectors from the device
90    fn read_sector(&self, sector: u64, buf: &mut [u8]) -> Result<(), BlockError>;
91    /// Write sectors to the device
92    fn write_sector(&self, sector: u64, buf: &[u8]) -> Result<(), BlockError>;
93    /// Get the total number of sectors
94    fn sector_count(&self) -> u64;
95}
96
/// Errors returned by `BlockDevice` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum BlockError {
    // Device reported failure, or the virtqueue rejected the request.
    #[error("device I/O error")]
    IoError,
    // Requested sector is at or beyond the device capacity.
    #[error("invalid sector number")]
    InvalidSector,
    // Caller's buffer is smaller than SECTOR_SIZE.
    #[error("buffer too small")]
    BufferTooSmall,
    // Frame allocation for DMA buffers failed.
    #[error("device not ready")]
    NotReady,
}
108
109/// VirtIO Block Device driver
110pub struct VirtioBlockDevice {
111    device: VirtioDevice,
112    queue: SpinLock<Virtqueue>,
113    capacity: u64,
114}
115
// SAFETY: all mutable state (the virtqueue) is guarded by a SpinLock, and
// requests are fully serialized through it.
// NOTE(review): this also asserts that VirtioDevice's register accesses are
// safe from multiple cores concurrently — confirm against the transport
// implementation in hardware::virtio::common.
unsafe impl Send for VirtioBlockDevice {}
unsafe impl Sync for VirtioBlockDevice {}
119
120impl VirtioBlockDevice {
121    /// Initialize a VirtIO block device from a PCI device
122    ///
123    /// # Safety
124    /// The PCI device must be a valid VirtIO block device
125    pub unsafe fn new(pci_dev: PciDevice) -> Result<Self, &'static str> {
126        log::info!("VirtIO-blk: Initializing device at {:?}", pci_dev.address);
127
128        // Create VirtIO device
129        let device = VirtioDevice::new(pci_dev)?;
130
131        // Reset device
132        device.reset();
133
134        // Acknowledge device
135        device.add_status(status::ACKNOWLEDGE as u8);
136
137        // Indicate we know how to drive it
138        device.add_status(status::DRIVER as u8);
139
140        // Read and negotiate features
141        let device_features = device.read_device_features();
142        log::debug!("VirtIO-blk: Device features: 0x{:08x}", device_features);
143
144        // We don't need any special features for basic operation yet
145        let guest_features = 0;
146        device.write_guest_features(guest_features);
147
148        // Features OK
149        device.add_status(status::FEATURES_OK as u8);
150
151        // Verify features OK
152        if device.get_status() & (status::FEATURES_OK as u8) == 0 {
153            return Err("Device doesn't support our feature set");
154        }
155
156        // Create virtqueue (queue 0 is the request queue)
157        let queue = Virtqueue::new(128)?;
158
159        // Setup queue with device
160        device.setup_queue(0, &queue);
161
162        // Driver ready
163        device.add_status(status::DRIVER_OK as u8);
164
165        // Read device capacity from config space (offset 0 in device-specific config)
166        // For legacy devices, device-specific config starts at offset 20 (after header)
167        // Or strictly speaking, after common config.
168        // Legacy VirtIO Header: 20 bytes.
169        // Block Config starts at offset 20.
170        let capacity_low = device.read_reg_u32(20);
171        let capacity_high = device.read_reg_u32(24);
172        let capacity = ((capacity_high as u64) << 32) | (capacity_low as u64);
173
174        log::info!(
175            "VirtIO-blk: Capacity: {} sectors ({} MB)",
176            capacity,
177            (capacity * SECTOR_SIZE as u64) / (1024 * 1024)
178        );
179
180        log::info!("VirtIO-blk: Device initialized successfully");
181
182        Ok(Self {
183            device,
184            queue: SpinLock::new(queue),
185            capacity,
186        })
187    }
188
189    /// Submit a block request and wait for completion
190    fn do_request(
191        &self,
192        request_type: RequestType,
193        sector: u64,
194        mut data_buf: Option<(&mut [u8], bool)>, // (buffer, is_write)
195    ) -> Result<(), BlockError> {
196        // Allocate request header and status byte
197        // We use a single frame for both if possible, or small allocations?
198        // To be safe and simple with the frame allocator, we'll alloc a frame.
199        // In a real optimized driver, we would have a slab allocator or pre-allocated pool.
200
201        let metadata_frame = crate::sync::with_irqs_disabled(|token| {
202            memory::allocate_frame(token)
203        })
204        .map_err(|_| BlockError::NotReady)?;
205
206        let metadata_phys = metadata_frame.start_address.as_u64();
207        let metadata_virt = crate::memory::phys_to_virt(metadata_phys);
208        let status_offset = mem::size_of::<BlockRequestHeader>() as u64;
209
210        // Layout: [Header (16 bytes)] ... [Status (1 byte)]
211        let header_ptr = metadata_virt as *mut BlockRequestHeader;
212        let status_ptr = (metadata_virt + status_offset) as *mut u8;
213
214        // Setup request header (CPU access → virtual address)
215        unsafe {
216            ptr::write(
217                header_ptr,
218                BlockRequestHeader {
219                    request_type: request_type as u32,
220                    reserved: 0,
221                    sector,
222                },
223            );
224            ptr::write(status_ptr, 0xFF); // Initialize with invalid status
225        }
226
227        // Build descriptor chain (DMA → physical addresses)
228        let mut buffers = Vec::with_capacity(3);
229
230        // 1. Header (Device Readable) — physical addr for DMA
231        buffers.push((
232            metadata_phys,
233            mem::size_of::<BlockRequestHeader>() as u32,
234            false,
235        ));
236
237        // 2. Data (Device Readable OR Writable)
238        // If data_buf is provided
239        let mut data_frame_info = None;
240
241        if let Some((buf, is_write)) = data_buf.as_mut() {
242            // We need a physically contiguous buffer for DMA
243            // For now, we allocate a bounce buffer.
244            // TODO: Support scatter-gather if the input buffer crosses page boundaries or isn't physical.
245
246            let buf_size = buf.len();
247            let buf_pages = (buf_size + 4095) / 4096;
248            let buf_order = buf_pages.next_power_of_two().trailing_zeros() as u8;
249
250            let buf_frame = crate::sync::with_irqs_disabled(|token| {
251                memory::allocate_frames(token, buf_order)
252            })
253            .map_err(|_| {
254                crate::sync::with_irqs_disabled(|token| {
255                    memory::free_frame(token, metadata_frame);
256                });
257                BlockError::NotReady
258            })?;
259
260            let buf_phys = buf_frame.start_address.as_u64();
261            let buf_virt = crate::memory::phys_to_virt(buf_phys);
262            data_frame_info = Some((buf_frame, buf_order));
263
264            // If WRITE (Out): Copy from source buf to DMA bounce buffer (CPU access → virtual)
265            if *is_write {
266                unsafe {
267                    ptr::copy_nonoverlapping(buf.as_ptr(), buf_virt as *mut u8, buf_size);
268                }
269            }
270
271            // `is_write` param tells us if we are writing TO disk.
272            // Write to disk: device reads from memory (flags = 0)
273            // Read from disk: device writes to memory (flags = WRITE)
274            let device_writable = !*is_write;
275
276            // DMA → physical address
277            buffers.push((buf_phys, buf_size as u32, device_writable));
278        }
279
280        // 3. Status (Device Writable) — physical addr for DMA
281        buffers.push((metadata_phys + status_offset, 1, true));
282
283        // Submit request
284        let mut queue = self.queue.lock();
285        let token = match queue.add_buffer(&buffers) {
286            Ok(t) => t,
287            Err(_) => {
288                drop(queue);
289                // Cleanup
290                crate::sync::with_irqs_disabled(|token| {
291                    memory::free_frame(token, metadata_frame);
292                    if let Some((f, o)) = data_frame_info {
293                        memory::free_frames(token, f, o);
294                    }
295                });
296                return Err(BlockError::IoError);
297            }
298        };
299
300        // Notify device
301        if queue.should_notify() {
302            self.device.notify_queue(0);
303        }
304        drop(queue);
305
306        // Wait for completion (busy-poll for now).
307        //
308        // IMPORTANT:
309        // Do not use HLT here. This path can run from syscall context where IF
310        // may be masked, and HLT would deadlock the CPU.
311        // TODO: replace with proper waitqueue + interrupt completion.
312        loop {
313            let mut queue = self.queue.lock();
314            if queue.has_used() {
315                // We don't check the token because we are single-threaded/blocking per device for now
316                // But to be correct we should find OUR token.
317                // virtio::common::Virtqueue::get_used currently pops the *next* used.
318                // If there are multiple in flight, we might pop someone else's.
319                // But here we are blocking, so only one in flight effectively.
320                if let Some((used_token, _len)) = queue.get_used() {
321                    if used_token == token {
322                        break;
323                    } else {
324                        // This shouldn't happen in single-threaded blocking mode
325                        // If it does, we just dropped someone else's completion.
326                        log::warn!("VirtIO-blk: Received unexpected token {}", used_token);
327                    }
328                }
329            }
330            drop(queue);
331            core::hint::spin_loop();
332        }
333
334        // Check status
335        let status_byte = unsafe { ptr::read(status_ptr) };
336
337        // Post-processing
338        if let Some((buf, is_write)) = data_buf {
339            if let Some((buf_frame, buf_order)) = data_frame_info {
340                let buf_virt = crate::memory::phys_to_virt(buf_frame.start_address.as_u64());
341
342                // If Read (In): Copy from DMA buf to destination buf (CPU access → virtual)
343                if !is_write && status_byte == BlockStatus::Ok as u8 {
344                    unsafe {
345                        ptr::copy_nonoverlapping(
346                            buf_virt as *const u8,
347                            buf.as_mut_ptr(),
348                            buf.len(),
349                        );
350                    }
351                }
352
353                // Free DMA buffer
354                crate::sync::with_irqs_disabled(|token| {
355                    memory::free_frames(token, buf_frame, buf_order);
356                });
357            }
358        }
359
360        // Free metadata frame
361        crate::sync::with_irqs_disabled(|token| {
362            memory::free_frame(token, metadata_frame);
363        });
364
365        if status_byte == BlockStatus::Ok as u8 {
366            Ok(())
367        } else {
368            log::error!("VirtIO-blk: Request failed with status {}", status_byte);
369            Err(BlockError::IoError)
370        }
371    }
372}
373
374impl BlockDevice for VirtioBlockDevice {
375    /// Reads sector.
376    fn read_sector(&self, sector: u64, buf: &mut [u8]) -> Result<(), BlockError> {
377        if sector >= self.capacity {
378            return Err(BlockError::InvalidSector);
379        }
380
381        if buf.len() < SECTOR_SIZE {
382            return Err(BlockError::BufferTooSmall);
383        }
384
385        self.do_request(RequestType::In, sector, Some((buf, false)))
386    }
387
388    /// Writes sector.
389    fn write_sector(&self, sector: u64, buf: &[u8]) -> Result<(), BlockError> {
390        if sector >= self.capacity {
391            return Err(BlockError::InvalidSector);
392        }
393
394        if buf.len() < SECTOR_SIZE {
395            return Err(BlockError::BufferTooSmall);
396        }
397
398        // Need mutable buffer for internal DMA operations signature (though we won't modify it if is_write=true)
399        // Our do_request takes &mut [u8], so we need to either change do_request or cast.
400        // It's safer to copy the input slice to a temp buffer if we needed to, but here do_request copies it to DMA anyway.
401        // But do_request signature expects &mut [u8] because it handles both read and write.
402        // We can just cast const to mut since we know we won't write to it if is_write=true.
403        // Or better, let's fix do_request to take Option<(&mut [u8], Direction)>.
404        // For now, let's do a safe copy to avoid unsafe hacks.
405
406        let mut buf_copy = [0u8; SECTOR_SIZE];
407        buf_copy[..SECTOR_SIZE].copy_from_slice(&buf[..SECTOR_SIZE]);
408
409        self.do_request(RequestType::Out, sector, Some((&mut buf_copy, true)))
410    }
411
412    /// Performs the sector count operation.
413    fn sector_count(&self) -> u64 {
414        self.capacity
415    }
416}
417
/// Global VirtIO block device.
///
/// Populated once by `init`; never cleared afterwards (a fact that
/// `get_device` relies on for its `'static` reference).
static VIRTIO_BLOCK: SpinLock<Option<Box<VirtioBlockDevice>>> = SpinLock::new(None);

/// VirtIO block IRQ line (will be set during init)
// NOTE(review): plain `static mut`, written once by `init` and read by
// `get_irq` with no synchronization. Sound only if init completes before
// any reader runs; an AtomicU8 would make this unconditionally safe.
static mut VIRTIO_BLOCK_IRQ: u8 = 0;
423
424/// Initialize VirtIO block device
425///
426/// Scans PCI bus for VirtIO block devices and initializes the first one found.
427pub fn init() {
428    log::info!("VirtIO-blk: Scanning for devices...");
429
430    // Prefer strict class-based probe (mass storage), with fallback to
431    // vendor+device for odd firmware/virtual setups.
432    let pci_dev = match pci::probe_first(pci::ProbeCriteria {
433        vendor_id: Some(pci::vendor::VIRTIO),
434        device_id: Some(pci::device::VIRTIO_BLOCK),
435        class_code: Some(pci::class::MASS_STORAGE),
436        subclass: None,
437        prog_if: None,
438    })
439    .or_else(|| pci::find_virtio_device(pci::device::VIRTIO_BLOCK))
440    {
441        Some(dev) => dev,
442        None => {
443            log::warn!("VirtIO-blk: No block device found");
444            return;
445        }
446    };
447
448    // Read interrupt line from PCI config
449    let irq_line = pci_dev.read_config_u8(pci::config::INTERRUPT_LINE);
450
451    // Initialize device
452    match unsafe { VirtioBlockDevice::new(pci_dev) } {
453        Ok(device) => {
454            // Store IRQ line for interrupt handler
455            unsafe {
456                VIRTIO_BLOCK_IRQ = irq_line;
457            }
458
459            // Register device
460            *VIRTIO_BLOCK.lock() = Some(Box::new(device));
461
462            // Register IRQ handler in IDT
463            crate::arch::x86_64::idt::register_virtio_block_irq(irq_line);
464
465            log::info!("VirtIO-blk: Device initialized on IRQ {}", irq_line);
466        }
467        Err(e) => {
468            log::error!("VirtIO-blk: Failed to initialize device: {}", e);
469        }
470    }
471}
472
473/// Handle VirtIO block device interrupt
474///
475/// Called from the IDT IRQ handler when the VirtIO device signals completion.
476/// Acknowledges the interrupt and processes completed requests.
477pub fn handle_interrupt() {
478    // Acknowledge the interrupt at the device level
479    let lock = VIRTIO_BLOCK.lock();
480    if let Some(device) = lock.as_ref() {
481        // Read ISR status to check if interrupt is for us
482        let isr_status = device.device.read_isr_status();
483        if isr_status != 0 {
484            // Acknowledge the interrupt
485            device.device.ack_interrupt();
486
487            // Process completed requests (wake up waiting tasks)
488            // For now, just log the interrupt
489            log::trace!("VirtIO-blk: Interrupt handled (ISR={})", isr_status);
490        }
491    }
492}
493
494/// Get the global VirtIO block device
495pub fn get_device() -> Option<&'static VirtioBlockDevice> {
496    unsafe {
497        let lock = VIRTIO_BLOCK.lock();
498        if lock.is_some() {
499            // This is slightly unsafe if the lock is dropped and the box is moved,
500            // but the static Option is never cleared in this kernel.
501            // A safer way is needed for production.
502            let ptr = &**lock.as_ref().unwrap() as *const VirtioBlockDevice;
503            Some(&*ptr)
504        } else {
505            None
506        }
507    }
508}
509
/// Get the VirtIO block IRQ line
///
/// Returns 0 if `init` has not yet run or found no device.
pub fn get_irq() -> u8 {
    // Unsynchronized read of a `static mut`; sound only because the value
    // is written exactly once during init — TODO confirm no reader can
    // race with `init` on another CPU.
    unsafe { VIRTIO_BLOCK_IRQ }
}