Skip to main content

strat9_kernel/hardware/nic/
virtio_net.rs

1//! VirtIO Network Device driver
2//!
3//! Provides network I/O via VirtIO-net protocol for QEMU/KVM environments.
4//! Implements the common [`crate::hardware::nic::NetworkDevice`] trait so
5//! this driver plugs into the unified `/dev/net/` scheme.
6//!
7//! Reference: VirtIO spec v1.2, Section 5.1 (Network Device)
8
9use crate::{
10    arch::x86_64::pci::{self, PciDevice},
11    hardware::{
12        nic as net,
13        virtio::{
14            common::{VirtioDevice, Virtqueue},
15            status,
16        },
17    },
18    memory::{self, PhysFrame},
19    sync::{FixedQueue, SpinLock},
20};
21use alloc::sync::Arc;
22use core::{mem, ptr};
23use net_core::{NetError, NetworkDevice};
24use spin::RwLock as SpinRwLock;
25
26/// VirtIO net header size
27const NET_HDR_SIZE: usize = mem::size_of::<VirtioNetHeader>();
28const RX_FRAME_TRACK_CAPACITY: usize = 128;
29
/// VirtIO-net feature bits, as found in the (low 32-bit) feature word.
pub mod features {
    /// Bit 0: device handles packets with a partial checksum.
    pub const VIRTIO_NET_F_CSUM: u32 = 0x0000_0001;
    /// Bit 1: driver handles packets with a partial checksum.
    pub const VIRTIO_NET_F_GUEST_CSUM: u32 = 0x0000_0002;
    /// Bit 5: device has a MAC address in its config space.
    pub const VIRTIO_NET_F_MAC: u32 = 0x0000_0020;
    /// Bit 6: (legacy) device handles packets with any GSO type.
    pub const VIRTIO_NET_F_GSO: u32 = 0x0000_0040;
    /// Bit 7: driver can receive TSOv4.
    pub const VIRTIO_NET_F_GUEST_TSO4: u32 = 0x0000_0080;
    /// Bit 8: driver can receive TSOv6.
    pub const VIRTIO_NET_F_GUEST_TSO6: u32 = 0x0000_0100;
    /// Bit 9: driver can receive TSO with ECN.
    pub const VIRTIO_NET_F_GUEST_ECN: u32 = 0x0000_0200;
    /// Bit 10: driver can receive UFO.
    pub const VIRTIO_NET_F_GUEST_UFO: u32 = 0x0000_0400;
    /// Bit 11: device can receive TSOv4.
    pub const VIRTIO_NET_F_HOST_TSO4: u32 = 0x0000_0800;
    /// Bit 12: device can receive TSOv6.
    pub const VIRTIO_NET_F_HOST_TSO6: u32 = 0x0000_1000;
    /// Bit 13: device can receive TSO with ECN.
    pub const VIRTIO_NET_F_HOST_ECN: u32 = 0x0000_2000;
    /// Bit 14: device can receive UFO.
    pub const VIRTIO_NET_F_HOST_UFO: u32 = 0x0000_4000;
    /// Bit 15: driver can merge receive buffers.
    pub const VIRTIO_NET_F_MRG_RXBUF: u32 = 0x0000_8000;
    /// Bit 16: the config-space `status` field is available.
    pub const VIRTIO_NET_F_STATUS: u32 = 0x0001_0000;
    /// Bit 17: control virtqueue is available.
    pub const VIRTIO_NET_F_CTRL_VQ: u32 = 0x0002_0000;
    /// Bit 18: control channel RX-mode support.
    pub const VIRTIO_NET_F_CTRL_RX: u32 = 0x0004_0000;
    /// Bit 19: control channel VLAN filtering.
    pub const VIRTIO_NET_F_CTRL_VLAN: u32 = 0x0008_0000;
    /// Bit 21: driver can send gratuitous announcements.
    pub const VIRTIO_NET_F_GUEST_ANNOUNCE: u32 = 0x0020_0000;
    /// Bit 22: device supports multiqueue with automatic receive steering.
    pub const VIRTIO_NET_F_MQ: u32 = 0x0040_0000;
}
52
/// Bits of the `status` field in the virtio-net device config space.
pub mod net_status {
    /// Bit 0: the link is up.
    pub const VIRTIO_NET_S_LINK_UP: u16 = 1 << 0;
    /// Bit 1: the device asks the driver to send a gratuitous announcement.
    pub const VIRTIO_NET_S_ANNOUNCE: u16 = 1 << 1;
}
58
/// Per-packet virtio-net header that precedes every frame in an RX/TX buffer.
///
/// `#[repr(C)]` pins the field order and layout to match the spec's
/// `struct virtio_net_hdr`. An all-zero value (`Default`) requests no checksum
/// or segmentation offload.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct VirtioNetHeader {
    /// Header flags (e.g. "checksum needed").
    pub flags: u8,
    /// Segmentation-offload type of the packet.
    pub gso_type: u8,
    /// Length of the protocol headers, for offload.
    pub hdr_len: u16,
    /// Segment size, for offload.
    pub gso_size: u16,
    /// Offset at which checksumming starts.
    pub csum_start: u16,
    /// Offset from `csum_start` where the checksum is stored.
    pub csum_offset: u16,
    /// Number of merged RX buffers. Only present on the wire when mergeable
    /// RX buffers (or VirtIO 1.0+) are negotiated.
    pub num_buffers: u16,
}
71
72/// VirtIO Network Device driver
73pub struct VirtioNetDevice {
74    device: VirtioDevice,
75    rx_queue: SpinLock<Virtqueue>,
76    tx_queue: SpinLock<Virtqueue>,
77    mac_address: [u8; 6],
78    pub rx_frames: SpinLock<FixedQueue<(PhysFrame, u8), RX_FRAME_TRACK_CAPACITY>>, // Track allocated RX frames
79}
80
81// Send and Sync are safe because we use SpinLocks
82unsafe impl Send for VirtioNetDevice {}
83unsafe impl Sync for VirtioNetDevice {}
84
85impl VirtioNetDevice {
86    /// Initialize a VirtIO network device from a PCI device
87    pub unsafe fn new(pci_dev: PciDevice) -> Result<Self, &'static str> {
88        log::info!("VirtIO-net: Initializing device at {:?}", pci_dev.address);
89
90        // Create VirtIO device
91        let device = VirtioDevice::new(pci_dev)?;
92
93        // Reset device
94        device.reset();
95
96        // Acknowledge device
97        device.add_status(status::ACKNOWLEDGE as u8);
98
99        // Indicate we know how to drive it
100        device.add_status(status::DRIVER as u8);
101
102        // Read and negotiate features
103        let _device_features = device.read_device_features();
104
105        // Request MAC address feature
106        let guest_features = features::VIRTIO_NET_F_MAC | features::VIRTIO_NET_F_STATUS;
107        device.write_guest_features(guest_features);
108
109        // Features OK
110        device.add_status(status::FEATURES_OK as u8);
111
112        // Verify features OK
113        if device.get_status() & (status::FEATURES_OK as u8) == 0 {
114            return Err("Device doesn't support our feature set");
115        }
116
117        // Create virtqueues
118        // Queue 0: RX (receive)
119        // Queue 1: TX (transmit)
120        let rx_queue = Virtqueue::new(128)?;
121        let tx_queue = Virtqueue::new(128)?;
122
123        // Setup queues with device
124        device.setup_queue(0, &rx_queue);
125        device.setup_queue(1, &tx_queue);
126
127        // Read MAC address from device config space
128        // For legacy devices, MAC is at offset 20 + 0
129        let mut mac_address = [0u8; 6];
130        for i in 0..6 {
131            mac_address[i] = device.read_reg_u8(20 + i as u16);
132        }
133
134        log::info!(
135            "VirtIO-net: MAC address: {:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
136            mac_address[0],
137            mac_address[1],
138            mac_address[2],
139            mac_address[3],
140            mac_address[4],
141            mac_address[5]
142        );
143
144        // Driver ready
145        device.add_status(status::DRIVER_OK as u8);
146
147        let net_device = Self {
148            device,
149            rx_queue: SpinLock::new(rx_queue),
150            tx_queue: SpinLock::new(tx_queue),
151            mac_address,
152            rx_frames: SpinLock::new(FixedQueue::new()),
153        };
154
155        // Fill RX queue with buffers
156        net_device.refill_rx_queue()?;
157
158        Ok(net_device)
159    }
160
161    /// Fill the RX queue with receive buffers
162    fn refill_rx_queue(&self) -> Result<(), &'static str> {
163        let mut rx_queue = self.rx_queue.lock();
164        let mut rx_frames = self.rx_frames.lock();
165
166        // We want to keep some buffers in the RX queue
167        let current_filled = rx_frames.len();
168        let target_filled = 64;
169        let mut added = 0usize;
170
171        if current_filled >= target_filled {
172            return Ok(());
173        }
174
175        for _ in 0..(target_filled - current_filled) {
176            // Allocate buffer for header + MTU
177            let buf_size = NET_HDR_SIZE + net::MTU;
178            let buf_pages = (buf_size + 4095) / 4096;
179            let buf_order = buf_pages.next_power_of_two().trailing_zeros() as u8;
180
181            let buf_frame = match crate::sync::with_irqs_disabled(|token| {
182                memory::allocate_phys_contiguous(token, buf_order)
183            }) {
184                Ok(frame) => frame,
185                Err(_) => break, // No more memory available
186            };
187
188            let buf_addr = buf_frame.start_address.as_u64();
189            let virt_addr = crate::memory::phys_to_virt(buf_addr);
190
191            // Zero the buffer (header needs to be zeroed mostly)
192            unsafe {
193                ptr::write_bytes(virt_addr as *mut u8, 0, buf_size);
194            }
195
196            // Add buffer to RX queue (device Writable)
197            match rx_queue.add_buffer(&[(buf_addr, buf_size as u32, true)]) {
198                Ok(_) => {
199                    if rx_frames.push_back((buf_frame, buf_order)).is_err() {
200                        crate::sync::with_irqs_disabled(|token| {
201                            memory::free_phys_contiguous(token, buf_frame, buf_order);
202                        });
203                        break;
204                    }
205                    added += 1;
206                }
207                Err(_) => {
208                    // Queue full, free the buffer
209                    crate::sync::with_irqs_disabled(|token| {
210                        memory::free_phys_contiguous(token, buf_frame, buf_order);
211                    });
212                    break;
213                }
214            }
215        }
216
217        // Notify device about new RX buffers
218        if rx_queue.should_notify() {
219            self.device.notify_queue(0);
220        }
221
222        if rx_frames.is_empty() && current_filled == 0 && added == 0 {
223            return Err("Failed to allocate RX buffers");
224        }
225
226        Ok(())
227    }
228
229    /// Read link status from device
230    fn read_link_status(&self) -> u16 {
231        // Status is at offset 6 in device-specific config (offset 20 + 6 = 26)
232        self.device.read_reg_u16(26)
233    }
234}
235
236impl NetworkDevice for VirtioNetDevice {
237    /// Performs the name operation.
238    fn name(&self) -> &str {
239        "virtio-net"
240    }
241
242    /// Performs the receive operation.
243    fn receive(&self, buf: &mut [u8]) -> Result<usize, NetError> {
244        let mut rx_queue = self.rx_queue.lock();
245
246        // Check if there's a used buffer
247        if !rx_queue.has_used() {
248            return Err(NetError::NoPacket);
249        }
250
251        let (token, len) = rx_queue.get_used().ok_or(NetError::NoPacket)?;
252
253        let _desc_index = token as usize;
254        let _desc_table = rx_queue.desc_area(); // Physical address
255
256        let (frame, order) = self
257            .rx_frames
258            .lock()
259            .pop_front()
260            .ok_or(NetError::NotReady)?;
261
262        let buf_addr = frame.start_address.as_u64();
263        let virt_addr = crate::memory::phys_to_virt(buf_addr);
264
265        // Check if token matches what we expect?
266        // We can't easily without reading the descriptor.
267
268        let header_ptr = virt_addr as *const VirtioNetHeader;
269        let data_ptr = (virt_addr + NET_HDR_SIZE as u64) as *const u8;
270
271        let _header = unsafe { ptr::read(header_ptr) };
272        let packet_len = (len as usize).saturating_sub(NET_HDR_SIZE);
273
274        if buf.len() < packet_len {
275            // Buffer too small, packet lost
276            crate::sync::with_irqs_disabled(|token| {
277                memory::free_phys_contiguous(token, frame, order);
278            });
279            drop(rx_queue);
280            // We still need to refill.
281            let _ = self.refill_rx_queue();
282            return Err(NetError::BufferTooSmall);
283        }
284
285        // Copy packet data
286        if packet_len > 0 {
287            unsafe {
288                ptr::copy_nonoverlapping(data_ptr, buf.as_mut_ptr(), packet_len);
289            }
290        }
291
292        // Free the frame
293        crate::sync::with_irqs_disabled(|token| {
294            memory::free_phys_contiguous(token, frame, order);
295        });
296        drop(rx_queue);
297
298        // Refill RX queue
299        let _ = self.refill_rx_queue();
300
301        Ok(packet_len)
302    }
303
304    /// Performs the transmit operation.
305    fn transmit(&self, buf: &[u8]) -> Result<(), NetError> {
306        if buf.len() > net::MTU {
307            return Err(NetError::BufferTooSmall);
308        }
309
310        // Allocate TX buffer (header + data)
311        let buf_size = NET_HDR_SIZE + buf.len();
312        let buf_pages = (buf_size + 4095) / 4096;
313        let buf_order = buf_pages.next_power_of_two().trailing_zeros() as u8;
314
315        let buf_frame = crate::sync::with_irqs_disabled(|token| {
316            memory::allocate_phys_contiguous(token, buf_order)
317        })
318        .map_err(|_| NetError::NotReady)?;
319
320        let buf_addr = buf_frame.start_address.as_u64();
321        let virt_addr = crate::memory::phys_to_virt(buf_addr);
322
323        let header_ptr = virt_addr as *mut VirtioNetHeader;
324        let data_ptr = (virt_addr + NET_HDR_SIZE as u64) as *mut u8;
325
326        // Write header
327        unsafe {
328            ptr::write(header_ptr, VirtioNetHeader::default());
329            ptr::copy_nonoverlapping(buf.as_ptr(), data_ptr, buf.len());
330        }
331
332        // Submit to TX queue
333        let mut tx_queue = self.tx_queue.lock();
334        let token = tx_queue
335            .add_buffer(&[(buf_addr, buf_size as u32, false)]) // Device Readable
336            .map_err(|_| {
337                // Free buffer if failed
338                crate::sync::with_irqs_disabled(|token| {
339                    memory::free_phys_contiguous(token, buf_frame, buf_order);
340                });
341                NetError::TxQueueFull
342            })?;
343
344        if tx_queue.should_notify() {
345            self.device.notify_queue(1);
346        }
347        drop(tx_queue);
348
349        // Wait for completion (simple spin for now)
350        loop {
351            let mut tx_queue = self.tx_queue.lock();
352            if tx_queue.has_used() {
353                if let Some((used_token, _)) = tx_queue.get_used() {
354                    // Assuming correct order for now or just waiting for *any* completion which matches ours
355                    if used_token == token {
356                        break;
357                    }
358                }
359            }
360            drop(tx_queue);
361            core::hint::spin_loop();
362        }
363
364        // Free TX buffer
365        crate::sync::with_irqs_disabled(|token| {
366            memory::free_phys_contiguous(token, buf_frame, buf_order);
367        });
368
369        Ok(())
370    }
371
372    /// Performs the mac address operation.
373    fn mac_address(&self) -> [u8; 6] {
374        self.mac_address
375    }
376
377    /// Performs the link up operation.
378    fn link_up(&self) -> bool {
379        let status = self.read_link_status();
380        status & net_status::VIRTIO_NET_S_LINK_UP != 0
381    }
382}
383
384/// Global VirtIO network device
385static VIRTIO_NET: SpinRwLock<Option<Arc<VirtioNetDevice>>> = SpinRwLock::new(None);
386
387/// Initialize VirtIO network device and register it in the global net registry.
388pub fn init() {
389    log::info!("VirtIO-net: Scanning for devices...");
390
391    // Prefer strict class-based probe (network/ethernet), with fallback to
392    // vendor+device for odd firmware/virtual setups.
393    let pci_dev = match pci::probe_first(pci::ProbeCriteria {
394        vendor_id: Some(pci::vendor::VIRTIO),
395        device_id: Some(pci::device::VIRTIO_NET),
396        class_code: Some(pci::class::NETWORK),
397        subclass: Some(pci::net_subclass::ETHERNET),
398        prog_if: None,
399    })
400    .or_else(|| pci::find_virtio_device(pci::device::VIRTIO_NET))
401    {
402        Some(dev) => dev,
403        None => {
404            log::warn!("VirtIO-net: No network device found");
405            return;
406        }
407    };
408
409    match unsafe { VirtioNetDevice::new(pci_dev) } {
410        Ok(device) => {
411            let arc = Arc::new(device);
412            *VIRTIO_NET.write() = Some(arc.clone());
413            net::register_device(arc);
414        }
415        Err(e) => {
416            log::error!("VirtIO-net: Failed to initialize device: {}", e);
417        }
418    }
419}
420
421/// Get the VirtIO network device instance (if present).
422pub fn get_device() -> Option<Arc<VirtioNetDevice>> {
423    VIRTIO_NET.read().clone()
424}