Skip to main content

strat9_kernel/hardware/storage/
nvme.rs

1// NVMe block device driver
2// Reference: NVM Express Base Specification 2.0
3
4use crate::{
5    hardware::pci_client::{self as pci, Bar, ProbeCriteria},
6    memory::{allocate_dma_frame, phys_to_virt},
7};
8use alloc::{format, string::String, sync::Arc, vec::Vec};
9use core::{
10    ptr,
11    sync::atomic::{AtomicU8, Ordering},
12};
13use spin::Mutex;
14
/// Size in bytes of the DMA buffer handed to the controller for Identify data
/// (4 KiB); the buffer is zeroed to exactly this length before each command.
const NVME_PAGE_SIZE: usize = 1 << 12;
16
/// A memory cell that is only ever accessed with volatile loads and stores.
///
/// The driver overlays these cells on device registers by casting a raw
/// address to a reference, so the type must have exactly the layout of `T`
/// (`repr(transparent)` over `UnsafeCell<T>`, which itself has the layout of
/// `T`).
#[repr(transparent)]
struct VolatileCell<T> {
    // `UnsafeCell` is what makes storing through `&self` sound: without it the
    // compiler may assume memory behind a shared reference never changes.
    value: core::cell::UnsafeCell<T>,
}

impl<T> VolatileCell<T> {
    /// Performs a volatile load and returns the current value.
    fn read(&self) -> T
    where
        T: Copy,
    {
        // SAFETY: `self` is a live reference, so the inner pointer is valid
        // and properly aligned for `T`.
        unsafe { ptr::read_volatile(self.value.get()) }
    }

    /// Performs a volatile store of `val`. The previous value is overwritten
    /// without being dropped.
    fn write(&self, val: T) {
        // SAFETY: `self` is a live reference and the interior mutability is
        // expressed through `UnsafeCell`, so writing via `&self` is allowed.
        unsafe { ptr::write_volatile(self.value.get(), val) }
    }
}

// SAFETY: the cell owns nothing but a `T`; moving it to another thread is as
// safe as moving the `T` itself.
unsafe impl<T: Send> Send for VolatileCell<T> {}
// SAFETY: every access is a single volatile load or store; callers that need
// read-modify-write atomicity must provide their own serialization.
unsafe impl<T: Sync> Sync for VolatileCell<T> {}
38
39#[repr(C)]
40struct Capability {
41    value: VolatileCell<u64>,
42}
43
44impl Capability {
45    /// Performs the max queue entries operation.
46    fn max_queue_entries(&self) -> u16 {
47        (self.value.read() & 0xFFFF) as u16
48    }
49    /// Performs the doorbell stride operation.
50    fn doorbell_stride(&self) -> u64 {
51        (self.value.read() >> 32) & 0xF
52    }
53}
54
55#[repr(transparent)]
56struct Version {
57    value: VolatileCell<u32>,
58}
59
60#[repr(C)]
61struct ControllerConfig {
62    value: VolatileCell<u32>,
63}
64
65impl ControllerConfig {
66    /// Performs the clear io fields operation.
67    fn clear_io_fields(&self) {
68        let mut val = self.value.read();
69        val &= !(((0xF) << 16) | ((0xF) << 20) | ((0x7) << 4));
70        self.value.write(val);
71    }
72    /// Sets iosqes.
73    fn set_iosqes(&self, size: u32) {
74        let mut val = self.value.read();
75        val |= (size & 0xF) << 16;
76        self.value.write(val);
77    }
78    /// Sets iocqes.
79    fn set_iocqes(&self, size: u32) {
80        let mut val = self.value.read();
81        val |= (size & 0xF) << 20;
82        self.value.write(val);
83    }
84    /// Sets css.
85    fn set_css(&self, css: u32) {
86        let mut val = self.value.read();
87        val |= (css & 0x7) << 4;
88        self.value.write(val);
89    }
90    /// Sets enable.
91    fn set_enable(&self, enable: bool) {
92        let mut val = self.value.read();
93        if enable {
94            val |= 1;
95        } else {
96            val &= !1;
97        }
98        self.value.write(val);
99    }
100    /// Returns whether enabled.
101    fn is_enabled(&self) -> bool {
102        (self.value.read() & 1) != 0
103    }
104}
105
106#[repr(transparent)]
107struct ControllerStatus {
108    value: VolatileCell<u32>,
109}
110
111impl ControllerStatus {
112    /// Returns whether ready.
113    fn is_ready(&self) -> bool {
114        (self.value.read() & 1) != 0
115    }
116    /// Returns whether fatal.
117    fn is_fatal(&self) -> bool {
118        (self.value.read() >> 1) & 1 != 0
119    }
120}
121
122#[repr(C)]
123struct Registers {
124    capability: Capability,
125    version: Version,
126    _intms: VolatileCell<u32>,
127    _intmc: VolatileCell<u32>,
128    cc: ControllerConfig,
129    _reserved1: VolatileCell<u32>,
130    csts: ControllerStatus,
131    _reserved2: VolatileCell<u32>,
132    aqa: VolatileCell<u32>,
133    asq: VolatileCell<u64>,
134    acq: VolatileCell<u64>,
135}
136
/// Failure modes reported while bringing up or talking to a controller.
#[derive(Debug, Clone, Copy)]
enum NvmeError {
    /// The controller raised its fatal-status bit.
    ControllerFatal,
    /// A bounded spin-wait on the controller ran out of iterations.
    Timeout,
    /// The controller reported zero namespaces.
    InvalidNamespace,
    /// Allocation failed, a command did not complete, or it completed with a
    /// non-zero status.
    IoError,
}
144
/// Geometry of one namespace discovered during controller bring-up.
#[derive(Debug, Clone)]
pub struct NvmeNamespace {
    /// Namespace identifier used when addressing the namespace in commands.
    pub nsid: u32,
    /// Namespace size, counted in logical blocks.
    pub size: u64,
    /// Size of one logical block in bytes.
    pub block_size: u32,
}
151
152pub struct NvmeController {
153    registers: usize,
154    admin_queue: Mutex<QueuePair>,
155    namespaces: Vec<NvmeNamespace>,
156    pub name: String,
157}
158
159unsafe impl Send for NvmeController {}
160unsafe impl Sync for NvmeController {}
161
162impl NvmeController {
163    /// Creates a new instance.
164    unsafe fn new(registers: usize, name: String) -> Result<Self, NvmeError> {
165        let regs = &*(registers as *const Registers);
166        let dstrd = regs.capability.doorbell_stride() as usize;
167        let max_entries = regs.capability.max_queue_entries();
168        let queue_size = core::cmp::min(max_entries as usize, 1024);
169
170        let admin_queue = QueuePair::new(registers, queue_size, dstrd);
171        let mut controller = Self {
172            registers,
173            admin_queue: Mutex::new(admin_queue),
174            namespaces: Vec::new(),
175            name,
176        };
177
178        controller.init_admin_queue()?;
179        controller.identify_namespaces()?;
180        Ok(controller)
181    }
182
183    /// Performs the submit admin command operation.
184    fn submit_admin_command(&self, command: Command) -> Result<CompletionEntry, NvmeError> {
185        let mut admin = self.admin_queue.lock();
186        admin.submit_command(command).ok_or(NvmeError::IoError)
187    }
188
189    /// Initializes admin queue.
190    fn init_admin_queue(&mut self) -> Result<(), NvmeError> {
191        let regs = unsafe { &*(self.registers as *const Registers) };
192        let (admin_sq_phys, admin_cq_phys, queue_size) = {
193            let q = self.admin_queue.lock();
194            (q.submission_phys(), q.completion_phys(), q.size)
195        };
196
197        if queue_size == 0 {
198            return Err(NvmeError::IoError);
199        }
200        let qsz = ((queue_size as u32).saturating_sub(1)) & 0x0FFF;
201
202        if regs.cc.is_enabled() {
203            regs.cc.set_enable(false);
204            let mut disable_timeout = 1_000_000u32;
205            while regs.csts.is_ready() {
206                core::hint::spin_loop();
207                disable_timeout = disable_timeout.saturating_sub(1);
208                if disable_timeout == 0 {
209                    return Err(NvmeError::Timeout);
210                }
211            }
212        }
213
214        regs.aqa.write(qsz | (qsz << 16));
215        regs.asq.write(admin_sq_phys);
216        regs.acq.write(admin_cq_phys);
217
218        regs.cc.clear_io_fields();
219        regs.cc.set_css(0);
220        regs.cc.set_iosqes(6);
221        regs.cc.set_iocqes(6);
222        regs.cc.set_enable(true);
223
224        let mut timeout = 1_000_000;
225        while !regs.csts.is_ready() {
226            core::hint::spin_loop();
227            timeout -= 1;
228            if timeout == 0 {
229                return Err(NvmeError::Timeout);
230            }
231        }
232
233        if regs.csts.is_fatal() {
234            return Err(NvmeError::ControllerFatal);
235        }
236
237        log::info!(
238            "NVMe: Controller v{}.{}.{} ready",
239            regs.version.value.read() >> 16,
240            (regs.version.value.read() >> 8) & 0xFF,
241            regs.version.value.read() & 0xFF
242        );
243        Ok(())
244    }
245
246    /// Performs the identify operation.
247    fn identify(&self, cns: u8, nsid: u32) -> Result<*mut u8, NvmeError> {
248        let frame = allocate_dma_frame().ok_or(NvmeError::IoError)?;
249        let phys = frame.start_address.as_u64();
250        let virt = phys_to_virt(phys) as *mut u8;
251        unsafe {
252            ptr::write_bytes(virt, 0, NVME_PAGE_SIZE);
253        }
254
255        let cmd = Command {
256            opcode: 0x06,
257            nsid,
258            prp1: phys,
259            cdw10: cns as u32,
260            ..Default::default()
261        };
262
263        let completion = self.submit_admin_command(cmd)?;
264        if completion.status_code() != 0 {
265            return Err(NvmeError::IoError);
266        }
267        Ok(virt)
268    }
269
270    /// Performs the identify namespaces operation.
271    fn identify_namespaces(&mut self) -> Result<(), NvmeError> {
272        let ctrl_data = self.identify(0x01, 0)?;
273        let nn = unsafe { ptr::read(ctrl_data.add(520) as *const u32) };
274        if nn == 0 {
275            return Err(NvmeError::InvalidNamespace);
276        }
277
278        for nsid in 1..=nn {
279            if let Ok(ns_data) = self.identify(0x00, nsid) {
280                unsafe {
281                    let nsze = ptr::read(ns_data.add(16) as *const u64);
282                    let flbas = ptr::read(ns_data.add(26) as *const u8) as usize;
283                    let lbaf_index = flbas & 0xF;
284                    let lbaf_offset = 128 + lbaf_index * 16;
285                    let lbaf_data = ptr::read(ns_data.add(lbaf_offset) as *const u16);
286                    let block_size = (1 << lbaf_data) as u32;
287
288                    self.namespaces.push(NvmeNamespace {
289                        nsid,
290                        size: nsze,
291                        block_size,
292                    });
293                    log::info!(
294                        "NVMe: Namespace {} - {} blocks @ {} bytes",
295                        nsid,
296                        nsze,
297                        block_size
298                    );
299                }
300            }
301        }
302        Ok(())
303    }
304
305    /// Performs the namespace count operation.
306    pub fn namespace_count(&self) -> usize {
307        self.namespaces.len()
308    }
309    /// Returns namespace.
310    pub fn get_namespace(&self, index: usize) -> Option<&NvmeNamespace> {
311        self.namespaces.get(index)
312    }
313}
314
/// One 64-byte submission queue entry, laid out exactly as the controller
/// reads it from memory.
///
/// Byte offsets: dword 0 (`opcode`, `flags`, `command_id`) at 0, `nsid` at 4,
/// `cdw2`/`cdw3` at 8/12, `mptr` at 16, `prp1` at 24, `prp2` at 32 and
/// `cdw10`..`cdw15` at 40..60.
#[repr(C)]
#[derive(Default, Copy, Clone)]
struct Command {
    /// Command opcode.
    opcode: u8,
    /// Fused-operation and data-pointer-type flags.
    flags: u8,
    /// Identifier echoed back in the matching completion entry.
    command_id: u16,
    /// Namespace the command applies to.
    nsid: u32,
    cdw2: u32,
    cdw3: u32,
    /// Metadata pointer. Without this field every later member would sit
    /// 8 bytes too early and the entry would be 56 bytes instead of 64.
    mptr: u64,
    /// First physical region page entry of the data buffer.
    prp1: u64,
    /// Second physical region page entry (or PRP list pointer).
    prp2: u64,
    cdw10: u32,
    cdw11: u32,
    cdw12: u32,
    cdw13: u32,
    cdw14: u32,
    cdw15: u32,
}
333
/// One 16-byte completion queue entry as written by the controller.
#[repr(C)]
#[derive(Copy, Clone)]
struct CompletionEntry {
    /// Command-specific result.
    dw0: u32,
    dw1: u32,
    /// Submission queue head pointer at the time of completion.
    sq_head: u16,
    /// Submission queue the command was taken from.
    sq_id: u16,
    /// Identifier of the command this entry completes.
    command_id: u16,
    /// Bit 0 is the phase tag; the status field sits above it.
    status: u16,
}

impl CompletionEntry {
    /// Returns bits 8:1 of the status word (the status code, with the phase
    /// tag shifted out).
    fn status_code(&self) -> u8 {
        let without_phase = self.status >> 1;
        // Narrowing to `u8` keeps exactly the low eight bits.
        without_phase as u8
    }
}
351
352struct QueuePair {
353    #[allow(dead_code)]
354    id: u16,
355    size: usize,
356    command_id: u16,
357    submission: Queue<Submission>,
358    completion: Queue<Completion>,
359}
360
361struct Submission;
362struct Completion;
363
364trait QueueType {
365    type EntryType;
366    const DOORBELL_OFFSET: usize;
367}
368
369impl QueueType for Submission {
370    type EntryType = Command;
371    const DOORBELL_OFFSET: usize = 0;
372}
373
374impl QueueType for Completion {
375    type EntryType = CompletionEntry;
376    const DOORBELL_OFFSET: usize = 1;
377}
378
379struct Queue<T: QueueType> {
380    doorbell: *const VolatileCell<u32>,
381    entries: *mut T::EntryType,
382    size: usize,
383    index: usize,
384    phase: bool,
385    phys_addr: u64,
386}
387
388impl<T: QueueType> Queue<T> {
389    /// Creates a new instance.
390    fn new(registers_base: usize, size: usize, queue_id: u16, dstrd: usize) -> Self {
391        let doorbell_offset =
392            0x1000 + ((((queue_id as usize) * 2) + T::DOORBELL_OFFSET) * (4 << dstrd));
393        let doorbell =
394            unsafe { &*((registers_base + doorbell_offset) as *const VolatileCell<u32>) };
395
396        let frame = allocate_dma_frame().expect("NVMe: failed to allocate queue frame");
397        let phys_addr = frame.start_address.as_u64();
398        let virt_addr = phys_to_virt(phys_addr);
399
400        unsafe {
401            ptr::write_bytes(
402                virt_addr as *mut u8,
403                0,
404                size * core::mem::size_of::<T::EntryType>(),
405            );
406        }
407
408        Self {
409            doorbell,
410            entries: virt_addr as *mut T::EntryType,
411            size,
412            index: 0,
413            phase: true,
414            phys_addr,
415        }
416    }
417
418    /// Performs the phys addr operation.
419    fn phys_addr(&self) -> u64 {
420        self.phys_addr
421    }
422}
423
424impl Queue<Completion> {
425    /// Performs the poll completion operation.
426    fn poll_completion(&mut self) -> Option<CompletionEntry> {
427        unsafe {
428            let entry = &*self.entries.add(self.index);
429            let status = entry.status;
430            if ((status & 0x1) != 0) == self.phase {
431                let completion = ptr::read(entry);
432                if (completion.status >> 9) & 0x7 != 0 || (completion.status >> 1) & 0xFF != 0 {
433                    log::error!("NVMe: completion error");
434                    return None;
435                }
436                self.index = (self.index + 1) % self.size;
437                if self.index == 0 {
438                    self.phase = !self.phase;
439                }
440                (*self.doorbell).write(self.index as u32);
441                Some(completion)
442            } else {
443                None
444            }
445        }
446    }
447}
448
449impl Queue<Submission> {
450    /// Performs the submit command operation.
451    fn submit_command(&mut self, command: Command, idx: usize) {
452        unsafe {
453            ptr::write(self.entries.add(idx), command);
454            (*self.doorbell).write(((idx + 1) % self.size) as u32);
455        }
456        core::sync::atomic::fence(core::sync::atomic::Ordering::SeqCst);
457    }
458}
459
460impl QueuePair {
461    /// Creates a new instance.
462    fn new(registers_base: usize, size: usize, dstrd: usize) -> Self {
463        static NEXT_ID: AtomicU8 = AtomicU8::new(0);
464        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst) as u16;
465        Self {
466            id,
467            size,
468            command_id: 0,
469            submission: Queue::new(registers_base, size, id, dstrd),
470            completion: Queue::new(registers_base, size, id, dstrd),
471        }
472    }
473
474    /// Performs the submission phys operation.
475    fn submission_phys(&self) -> u64 {
476        self.submission.phys_addr()
477    }
478    /// Performs the completion phys operation.
479    fn completion_phys(&self) -> u64 {
480        self.completion.phys_addr()
481    }
482
483    /// Performs the submit command operation.
484    fn submit_command(&mut self, command: Command) -> Option<CompletionEntry> {
485        let slot = self.command_id as usize % self.size;
486        let mut cmd = command;
487        unsafe {
488            ptr::write(&mut cmd.command_id as *mut u16, self.command_id);
489        }
490        self.command_id = self.command_id.wrapping_add(1);
491        self.submission.submit_command(cmd, slot);
492        let mut timeout = 5_000_000u32;
493        loop {
494            if let Some(c) = self.completion.poll_completion() {
495                return Some(c);
496            }
497            timeout = timeout.saturating_sub(1);
498            if timeout == 0 {
499                log::error!("NVMe: admin command timeout");
500                return None;
501            }
502            core::hint::spin_loop();
503        }
504    }
505}
506
507static NVME_CONTROLLERS: Mutex<Vec<Arc<NvmeController>>> = Mutex::new(Vec::new());
508
509/// Performs the init operation.
510pub fn init() {
511    log::info!("[NVMe] Scanning for NVMe controllers...");
512
513    let candidates = pci::probe_all(ProbeCriteria {
514        vendor_id: None,
515        device_id: None,
516        class_code: Some(pci::class::MASS_STORAGE),
517        subclass: Some(pci::storage_subclass::NVM),
518        prog_if: None,
519    });
520
521    for (i, pci_dev) in candidates.into_iter().enumerate() {
522        log::info!(
523            "NVMe: Found controller at {:?} (VEN:{:04x} DEV:{:04x})",
524            pci_dev.address,
525            pci_dev.vendor_id,
526            pci_dev.device_id
527        );
528
529        pci_dev.enable_bus_master();
530        pci_dev.enable_memory_space();
531
532        let bar = match pci_dev.read_bar(0) {
533            Some(Bar::Memory64 { addr, .. }) => addr,
534            _ => {
535                log::warn!("NVMe: Invalid BAR0");
536                continue;
537            }
538        };
539
540        let registers = phys_to_virt(bar) as usize;
541        let name = format!("nvme{}", i);
542
543        match unsafe { NvmeController::new(registers, name) } {
544            Ok(controller) => {
545                NVME_CONTROLLERS.lock().push(Arc::new(controller));
546            }
547            Err(e) => {
548                log::warn!("NVMe: Failed to initialize controller: {:?}", e);
549            }
550        }
551    }
552
553    log::info!(
554        "[NVMe] Found {} controller(s)",
555        NVME_CONTROLLERS.lock().len()
556    );
557}
558
559/// Returns first controller.
560pub fn get_first_controller() -> Option<Arc<NvmeController>> {
561    NVME_CONTROLLERS.lock().first().cloned()
562}
563
564/// Performs the list controllers operation.
565pub fn list_controllers() -> Vec<String> {
566    NVME_CONTROLLERS
567        .lock()
568        .iter()
569        .map(|c| c.name.clone())
570        .collect()
571}