Skip to main content

strat9_kernel/hardware/nic/
virtio_net.rs

1//! VirtIO Network Device driver
2//!
3//! Provides network I/O via VirtIO-net protocol for QEMU/KVM environments.
4//! Implements the common [`crate::hardware::nic::NetworkDevice`] trait so
5//! this driver plugs into the unified `/dev/net/` scheme.
6//!
7//! Reference: VirtIO spec v1.2, Section 5.1 (Network Device)
8
9use crate::{
10    arch::x86_64::pci::{self, PciDevice},
11    hardware::{
12        nic as net,
13        virtio::{
14            common::{VirtioDevice, Virtqueue},
15            status,
16        },
17    },
18    memory::{self, PhysFrame},
19    sync::SpinLock,
20};
21use alloc::{collections::VecDeque, sync::Arc};
22use core::{mem, ptr};
23use net_core::{NetError, NetworkDevice};
24use spin::RwLock as SpinRwLock;
25
26/// VirtIO net header size
27const NET_HDR_SIZE: usize = mem::size_of::<VirtioNetHeader>();
28
/// Feature bits defined for the VirtIO network device (low 32 bits only).
pub mod features {
    /// Builds the mask for a single feature bit position.
    const fn bit(position: u32) -> u32 {
        1 << position
    }

    /// Device handles packets with partial checksum.
    pub const VIRTIO_NET_F_CSUM: u32 = bit(0);
    /// Driver handles packets with partial checksum.
    pub const VIRTIO_NET_F_GUEST_CSUM: u32 = bit(1);
    /// Device has a given MAC address in its configuration space.
    pub const VIRTIO_NET_F_MAC: u32 = bit(5);
    /// Legacy: device handles packets with any GSO type.
    pub const VIRTIO_NET_F_GSO: u32 = bit(6);
    /// Driver can receive TSOv4.
    pub const VIRTIO_NET_F_GUEST_TSO4: u32 = bit(7);
    /// Driver can receive TSOv6.
    pub const VIRTIO_NET_F_GUEST_TSO6: u32 = bit(8);
    /// Driver can receive TSO with ECN.
    pub const VIRTIO_NET_F_GUEST_ECN: u32 = bit(9);
    /// Driver can receive UFO.
    pub const VIRTIO_NET_F_GUEST_UFO: u32 = bit(10);
    /// Device can receive TSOv4.
    pub const VIRTIO_NET_F_HOST_TSO4: u32 = bit(11);
    /// Device can receive TSOv6.
    pub const VIRTIO_NET_F_HOST_TSO6: u32 = bit(12);
    /// Device can receive TSO with ECN.
    pub const VIRTIO_NET_F_HOST_ECN: u32 = bit(13);
    /// Device can receive UFO.
    pub const VIRTIO_NET_F_HOST_UFO: u32 = bit(14);
    /// Driver can merge receive buffers.
    pub const VIRTIO_NET_F_MRG_RXBUF: u32 = bit(15);
    /// Configuration `status` field is available.
    pub const VIRTIO_NET_F_STATUS: u32 = bit(16);
    /// Control virtqueue is available.
    pub const VIRTIO_NET_F_CTRL_VQ: u32 = bit(17);
    /// Control channel supports RX mode commands.
    pub const VIRTIO_NET_F_CTRL_RX: u32 = bit(18);
    /// Control channel supports VLAN filtering.
    pub const VIRTIO_NET_F_CTRL_VLAN: u32 = bit(19);
    /// Driver can send gratuitous packets.
    pub const VIRTIO_NET_F_GUEST_ANNOUNCE: u32 = bit(21);
    /// Device supports multiqueue with automatic receive steering.
    pub const VIRTIO_NET_F_MQ: u32 = bit(22);
}
51
/// Bits of the 16-bit `status` field in the VirtIO net configuration space.
pub mod net_status {
    /// Set while the link is up.
    pub const VIRTIO_NET_S_LINK_UP: u16 = 1 << 0;
    /// Set when the device asks the driver to send an announcement.
    pub const VIRTIO_NET_S_ANNOUNCE: u16 = 1 << 1;
}
57
/// Per-packet header exchanged with the device in front of every frame.
///
/// The layout is `#[repr(C)]` and 12 bytes long; an all-zero header (the
/// `Default` value) requests no checksum or segmentation offload.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct VirtioNetHeader {
    /// Offload flags (`VIRTIO_NET_HDR_F_*` in the VirtIO spec).
    pub flags: u8,
    /// Segmentation offload type (`VIRTIO_NET_HDR_GSO_*`); zero means none.
    pub gso_type: u8,
    /// Length of the packet headers, used with segmentation offload.
    pub hdr_len: u16,
    /// Segment size, used with segmentation offload.
    pub gso_size: u16,
    /// Offset at which checksumming starts, used with checksum offload.
    pub csum_start: u16,
    /// Offset after `csum_start` at which the checksum is stored.
    pub csum_offset: u16,
    /// Number of merged receive buffers. Per the VirtIO spec, legacy devices
    /// only carry this field on the wire when `VIRTIO_NET_F_MRG_RXBUF` has
    /// been negotiated.
    pub num_buffers: u16,
}
70
71/// VirtIO Network Device driver
72pub struct VirtioNetDevice {
73    device: VirtioDevice,
74    rx_queue: SpinLock<Virtqueue>,
75    tx_queue: SpinLock<Virtqueue>,
76    mac_address: [u8; 6],
77    pub rx_frames: SpinLock<VecDeque<(PhysFrame, u8)>>, // Track allocated RX frames
78}
79
80// Send and Sync are safe because we use SpinLocks
81unsafe impl Send for VirtioNetDevice {}
82unsafe impl Sync for VirtioNetDevice {}
83
84impl VirtioNetDevice {
85    /// Initialize a VirtIO network device from a PCI device
86    pub unsafe fn new(pci_dev: PciDevice) -> Result<Self, &'static str> {
87        log::info!("VirtIO-net: Initializing device at {:?}", pci_dev.address);
88
89        // Create VirtIO device
90        let device = VirtioDevice::new(pci_dev)?;
91
92        // Reset device
93        device.reset();
94
95        // Acknowledge device
96        device.add_status(status::ACKNOWLEDGE as u8);
97
98        // Indicate we know how to drive it
99        device.add_status(status::DRIVER as u8);
100
101        // Read and negotiate features
102        let _device_features = device.read_device_features();
103
104        // Request MAC address feature
105        let guest_features = features::VIRTIO_NET_F_MAC | features::VIRTIO_NET_F_STATUS;
106        device.write_guest_features(guest_features);
107
108        // Features OK
109        device.add_status(status::FEATURES_OK as u8);
110
111        // Verify features OK
112        if device.get_status() & (status::FEATURES_OK as u8) == 0 {
113            return Err("Device doesn't support our feature set");
114        }
115
116        // Create virtqueues
117        // Queue 0: RX (receive)
118        // Queue 1: TX (transmit)
119        let rx_queue = Virtqueue::new(128)?;
120        let tx_queue = Virtqueue::new(128)?;
121
122        // Setup queues with device
123        device.setup_queue(0, &rx_queue);
124        device.setup_queue(1, &tx_queue);
125
126        // Read MAC address from device config space
127        // For legacy devices, MAC is at offset 20 + 0
128        let mut mac_address = [0u8; 6];
129        for i in 0..6 {
130            mac_address[i] = device.read_reg_u8(20 + i as u16);
131        }
132
133        log::info!(
134            "VirtIO-net: MAC address: {:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
135            mac_address[0],
136            mac_address[1],
137            mac_address[2],
138            mac_address[3],
139            mac_address[4],
140            mac_address[5]
141        );
142
143        // Driver ready
144        device.add_status(status::DRIVER_OK as u8);
145
146        let net_device = Self {
147            device,
148            rx_queue: SpinLock::new(rx_queue),
149            tx_queue: SpinLock::new(tx_queue),
150            mac_address,
151            rx_frames: SpinLock::new(VecDeque::new()),
152        };
153
154        // Fill RX queue with buffers
155        net_device.refill_rx_queue()?;
156
157        Ok(net_device)
158    }
159
160    /// Fill the RX queue with receive buffers
161    fn refill_rx_queue(&self) -> Result<(), &'static str> {
162        let mut rx_queue = self.rx_queue.lock();
163        let mut rx_frames = self.rx_frames.lock();
164
165        // We want to keep some buffers in the RX queue
166        let current_filled = rx_frames.len();
167        let target_filled = 64;
168        let mut added = 0usize;
169
170        if current_filled >= target_filled {
171            return Ok(());
172        }
173
174        for _ in 0..(target_filled - current_filled) {
175            // Allocate buffer for header + MTU
176            let buf_size = NET_HDR_SIZE + net::MTU;
177            let buf_pages = (buf_size + 4095) / 4096;
178            let buf_order = buf_pages.next_power_of_two().trailing_zeros() as u8;
179
180            let buf_frame = match crate::sync::with_irqs_disabled(|token| {
181                memory::allocate_frames(token, buf_order)
182            }) {
183                Ok(frame) => frame,
184                Err(_) => break, // No more memory available
185            };
186
187            let buf_addr = buf_frame.start_address.as_u64();
188            let virt_addr = crate::memory::phys_to_virt(buf_addr);
189
190            // Zero the buffer (header needs to be zeroed mostly)
191            unsafe {
192                ptr::write_bytes(virt_addr as *mut u8, 0, buf_size);
193            }
194
195            // Add buffer to RX queue (device Writable)
196            match rx_queue.add_buffer(&[(buf_addr, buf_size as u32, true)]) {
197                Ok(_) => {
198                    rx_frames.push_back((buf_frame, buf_order));
199                    added += 1;
200                }
201                Err(_) => {
202                    // Queue full, free the buffer
203                    crate::sync::with_irqs_disabled(|token| {
204                        memory::free_frames(token, buf_frame, buf_order);
205                    });
206                    break;
207                }
208            }
209        }
210
211        // Notify device about new RX buffers
212        if rx_queue.should_notify() {
213            self.device.notify_queue(0);
214        }
215
216        if rx_frames.is_empty() && current_filled == 0 && added == 0 {
217            return Err("Failed to allocate RX buffers");
218        }
219
220        Ok(())
221    }
222
223    /// Read link status from device
224    fn read_link_status(&self) -> u16 {
225        // Status is at offset 6 in device-specific config (offset 20 + 6 = 26)
226        self.device.read_reg_u16(26)
227    }
228}
229
230impl NetworkDevice for VirtioNetDevice {
231    /// Performs the name operation.
232    fn name(&self) -> &str {
233        "virtio-net"
234    }
235
236    /// Performs the receive operation.
237    fn receive(&self, buf: &mut [u8]) -> Result<usize, NetError> {
238        let mut rx_queue = self.rx_queue.lock();
239
240        // Check if there's a used buffer
241        if !rx_queue.has_used() {
242            return Err(NetError::NoPacket);
243        }
244
245        let (token, len) = rx_queue.get_used().ok_or(NetError::NoPacket)?;
246
247        let _desc_index = token as usize;
248        let _desc_table = rx_queue.desc_area(); // Physical address
249
250        let (frame, order) = self
251            .rx_frames
252            .lock()
253            .pop_front()
254            .ok_or(NetError::NotReady)?;
255
256        let buf_addr = frame.start_address.as_u64();
257        let virt_addr = crate::memory::phys_to_virt(buf_addr);
258
259        // Check if token matches what we expect?
260        // We can't easily without reading the descriptor.
261
262        let header_ptr = virt_addr as *const VirtioNetHeader;
263        let data_ptr = (virt_addr + NET_HDR_SIZE as u64) as *const u8;
264
265        let _header = unsafe { ptr::read(header_ptr) };
266        let packet_len = (len as usize).saturating_sub(NET_HDR_SIZE);
267
268        if buf.len() < packet_len {
269            // Buffer too small, packet lost
270            crate::sync::with_irqs_disabled(|token| {
271                memory::free_frames(token, frame, order);
272            });
273            drop(rx_queue);
274            // We still need to refill.
275            let _ = self.refill_rx_queue();
276            return Err(NetError::BufferTooSmall);
277        }
278
279        // Copy packet data
280        if packet_len > 0 {
281            unsafe {
282                ptr::copy_nonoverlapping(data_ptr, buf.as_mut_ptr(), packet_len);
283            }
284        }
285
286        // Free the frame
287        crate::sync::with_irqs_disabled(|token| {
288            memory::free_frames(token, frame, order);
289        });
290        drop(rx_queue);
291
292        // Refill RX queue
293        let _ = self.refill_rx_queue();
294
295        Ok(packet_len)
296    }
297
298    /// Performs the transmit operation.
299    fn transmit(&self, buf: &[u8]) -> Result<(), NetError> {
300        if buf.len() > net::MTU {
301            return Err(NetError::BufferTooSmall);
302        }
303
304        // Allocate TX buffer (header + data)
305        let buf_size = NET_HDR_SIZE + buf.len();
306        let buf_pages = (buf_size + 4095) / 4096;
307        let buf_order = buf_pages.next_power_of_two().trailing_zeros() as u8;
308
309        let buf_frame = crate::sync::with_irqs_disabled(|token| {
310            memory::allocate_frames(token, buf_order)
311        })
312        .map_err(|_| NetError::NotReady)?;
313
314        let buf_addr = buf_frame.start_address.as_u64();
315        let virt_addr = crate::memory::phys_to_virt(buf_addr);
316
317        let header_ptr = virt_addr as *mut VirtioNetHeader;
318        let data_ptr = (virt_addr + NET_HDR_SIZE as u64) as *mut u8;
319
320        // Write header
321        unsafe {
322            ptr::write(header_ptr, VirtioNetHeader::default());
323            ptr::copy_nonoverlapping(buf.as_ptr(), data_ptr, buf.len());
324        }
325
326        // Submit to TX queue
327        let mut tx_queue = self.tx_queue.lock();
328        let token = tx_queue
329            .add_buffer(&[(buf_addr, buf_size as u32, false)]) // Device Readable
330            .map_err(|_| {
331                // Free buffer if failed
332                crate::sync::with_irqs_disabled(|token| {
333                    memory::free_frames(token, buf_frame, buf_order);
334                });
335                NetError::TxQueueFull
336            })?;
337
338        if tx_queue.should_notify() {
339            self.device.notify_queue(1);
340        }
341        drop(tx_queue);
342
343        // Wait for completion (simple spin for now)
344        loop {
345            let mut tx_queue = self.tx_queue.lock();
346            if tx_queue.has_used() {
347                if let Some((used_token, _)) = tx_queue.get_used() {
348                    // Assuming correct order for now or just waiting for *any* completion which matches ours
349                    if used_token == token {
350                        break;
351                    }
352                }
353            }
354            drop(tx_queue);
355            core::hint::spin_loop();
356        }
357
358        // Free TX buffer
359        crate::sync::with_irqs_disabled(|token| {
360            memory::free_frames(token, buf_frame, buf_order);
361        });
362
363        Ok(())
364    }
365
366    /// Performs the mac address operation.
367    fn mac_address(&self) -> [u8; 6] {
368        self.mac_address
369    }
370
371    /// Performs the link up operation.
372    fn link_up(&self) -> bool {
373        let status = self.read_link_status();
374        status & net_status::VIRTIO_NET_S_LINK_UP != 0
375    }
376}
377
378/// Global VirtIO network device
379static VIRTIO_NET: SpinRwLock<Option<Arc<VirtioNetDevice>>> = SpinRwLock::new(None);
380
381/// Initialize VirtIO network device and register it in the global net registry.
382pub fn init() {
383    log::info!("VirtIO-net: Scanning for devices...");
384
385    // Prefer strict class-based probe (network/ethernet), with fallback to
386    // vendor+device for odd firmware/virtual setups.
387    let pci_dev = match pci::probe_first(pci::ProbeCriteria {
388        vendor_id: Some(pci::vendor::VIRTIO),
389        device_id: Some(pci::device::VIRTIO_NET),
390        class_code: Some(pci::class::NETWORK),
391        subclass: Some(pci::net_subclass::ETHERNET),
392        prog_if: None,
393    })
394    .or_else(|| pci::find_virtio_device(pci::device::VIRTIO_NET))
395    {
396        Some(dev) => dev,
397        None => {
398            log::warn!("VirtIO-net: No network device found");
399            return;
400        }
401    };
402
403    match unsafe { VirtioNetDevice::new(pci_dev) } {
404        Ok(device) => {
405            let arc = Arc::new(device);
406            *VIRTIO_NET.write() = Some(arc.clone());
407            net::register_device(arc);
408        }
409        Err(e) => {
410            log::error!("VirtIO-net: Failed to initialize device: {}", e);
411        }
412    }
413}
414
415/// Get the VirtIO network device instance (if present).
416pub fn get_device() -> Option<Arc<VirtioNetDevice>> {
417    VIRTIO_NET.read().clone()
418}