// strat9_kernel/hardware/storage/ahci.rs
//! AHCI (Advanced Host Controller Interface) driver : AHCI spec 1.3.1
//!
//! PCI: class=0x01 (Mass Storage), subclass=0x06 (SATA), prog_if=0x01
//! MMIO base: BAR5 (ABAR)
//!
//! Per-port memory layout (packed into one 4 KB page):
//!   [0x000..0x3FF]  Command List   (1024 B, 32 × 32-byte headers)
//!   [0x400..0x4FF]  FIS receive    (256 B)
//!   [0x500..0x5FF]  Command table  (128 B header + 1 × 16-byte PRDT)
//!
//! ## IRQ-driven completion (DRV-02 v2)
//!
//! When a task context exists (`current_task_id()` is `Some`), commands are
//! completed via interrupt + `WaitQueue::wait_until()` so the issuing task
//! blocks without busy-spinning.  During early boot (no task yet) the legacy
//! polling path is used as a fallback.
//!
//! Per-port statics (indexed by `port_num 0..32`) are used so the IRQ handler
//! can signal completion without acquiring any slow lock:
//!   - `PORT_VIRT[n]`       : MMIO virtual address of port n registers
//!   - `PORT_SLOT0_DONE[n]` : set by IRQ handler when slot-0 completes
//!   - `PORT_SLOT0_ERROR[n]`: set by IRQ handler when a task-file error fires
//!   - `PORT_WQ[n]`         : WaitQueue; issuing task blocks here
25use crate::{
26    hardware::pci_client::{self as pci, ProbeCriteria},
27    memory::{self, phys_to_virt, PhysFrame},
28    sync::{SpinLock, WaitQueue},
29};
30use alloc::{boxed::Box, vec::Vec};
31use core::{
32    ptr,
33    sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
34};
35
36pub use super::virtio_block::{BlockDevice, BlockError, SECTOR_SIZE};
37
// ========== HBA generic registers (byte offsets from ABAR) ==========
const HBA_GHC: u64 = 0x04; // Global HBA Control
const HBA_IS: u64 = 0x08; // global Interrupt Status (one bit per port, W1C)
const HBA_PI: u64 = 0x0C; // Ports Implemented bitmask

const GHC_AE: u32 = 1 << 31; // AHCI Enable
const GHC_IE: u32 = 1 << 1; // Global Interrupt Enable
const GHC_HR: u32 = 1 << 0; // HBA Reset

// ========== Port register offsets (relative to port base = ABAR + 0x100 + n*0x80)
const PORT_CLB: u64 = 0x00; // Command List Base (low 32 bits)
const PORT_CLBU: u64 = 0x04; // Command List Base (upper 32 bits)
const PORT_FB: u64 = 0x08; // FIS receive Base (low 32 bits)
const PORT_FBU: u64 = 0x0C; // FIS receive Base (upper 32 bits)
const PORT_IS: u64 = 0x10; // PxIS : port interrupt status (W1C)
const PORT_IE: u64 = 0x14; // PxIE : port interrupt enable
const PORT_CMD: u64 = 0x18; // PxCMD : command and status
const PORT_TFD: u64 = 0x20; // PxTFD : task file data (BSY/DRQ/error)
const PORT_SIG: u64 = 0x24; // PxSIG : attached device signature
const PORT_SSTS: u64 = 0x28; // PxSSTS : SATA status (DET field)
const PORT_SERR: u64 = 0x30; // PxSERR : SATA error (W1C)
const PORT_CI: u64 = 0x38; // PxCI : command issue (one bit per slot)

const CMD_ST: u32 = 1 << 0; // Start
const CMD_FRE: u32 = 1 << 4; // FIS Receive Enable
const CMD_FR: u32 = 1 << 14; // FIS Receive Running
const CMD_CR: u32 = 1 << 15; // Command List Running

const TFD_BSY: u32 = 1 << 7; // device busy
const TFD_DRQ: u32 = 1 << 3; // device requesting data transfer

const SSTS_DET_COMM: u32 = 3; // device present and PHY communication established
const SSTS_DET_MASK: u32 = 0xF;

const SIG_SATA: u32 = 0x0000_0101; // plain SATA device (not ATAPI/SEMB/PM)

// PxIE bits
const PXIE_DHRE: u32 = 1 << 0; // D2H Register FIS Received Enable (normal DMA completion)
const PXIE_TFEE: u32 = 1 << 30; // Task File Error Enable

// ========== Per-port memory layout offsets ==========
const CLB_OFF: u64 = 0x000; // Command List (1024 B)
const FB_OFF: u64 = 0x400; // FIS buffer   (256 B)
const CTAB_OFF: u64 = 0x500; // Command Table (128 B header + 16 B PRDT)

// Command header field byte offsets within a 32-byte slot
const CMDH_FLAGS: usize = 0; // u16: cfl[4:0] | a | w | p | r | b | c
const CMDH_PRDTL: usize = 2; // u16: number of PRDT entries
const CMDH_CTBA: usize = 8; // u32: command table base (low 32 bits)
const CMDH_CTBAU: usize = 12; // u32: command table base (upper 32 bits)

// Command table FIS and PRDT offsets
const CTAB_CFIS: usize = 0x00; // H2D FIS (64 B allocated)
const CTAB_PRDT: usize = 0x80; // PRDT entries

// H2D FIS field offsets (FIS type 0x27, Register Host-to-Device)
const FIS_TYPE: usize = 0;
const FIS_FLAGS: usize = 1; // PM port [3:0] | C [7]
const FIS_CMD: usize = 2;
const FIS_LBA0: usize = 4;
const FIS_LBA1: usize = 5;
const FIS_LBA2: usize = 6;
const FIS_DEVICE: usize = 7;
const FIS_LBA3: usize = 8;
const FIS_LBA4: usize = 9;
const FIS_LBA5: usize = 10;
const FIS_CNT_LO: usize = 12;
const FIS_CNT_HI: usize = 13;

const FIS_TYPE_H2D: u8 = 0x27;
const FIS_C_BIT: u8 = 0x80; // command (not control)
const FIS_LBA_MODE: u8 = 1 << 6; // device register: LBA addressing

// ATA commands (48-bit LBA)
const ATA_IDENTIFY: u8 = 0xEC;
const ATA_READ_DMA_EXT: u8 = 0x25;
const ATA_WRITE_DMA_EXT: u8 = 0x35;

// PxIS bit 30 = Task File Error Status
const PXIS_TFES: u32 = 1 << 30;
118
// ========== Error type ==========

/// Errors surfaced by the AHCI driver; mapped to `BlockError` at the
/// `BlockDevice` trait boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum AhciError {
    #[error("no AHCI controller on PCI bus")]
    NoController,
    #[error("invalid BAR5 (ABAR)")]
    BadAbar,
    #[error("physical memory allocation failed")]
    Alloc,
    #[error("port BSY/DRQ set")]
    Busy,
    #[error("command timed out")]
    Timeout,
    #[error("device reported task-file error")]
    DeviceError,
    #[error("invalid sector number")]
    InvalidSector,
    #[error("buffer too small (need ≥ SECTOR_SIZE bytes)")]
    BufferTooSmall,
    #[error("no usable SATA port found")]
    NoPort,
}
142
// ========== Internal port handle ==========

/// Per-port state: the port's register window plus the single 4 KB frame
/// holding its command list, FIS receive area, and command table.
struct AhciPort {
    port_num: u8,      // port index 0..32 (bit position in HBA_PI)
    port_virt: u64, // virtual address of port registers
    mem_phys: u64,  // physical base of the per-port CLB/FB/CTAB frame
    mem_virt: u64,  // HHDM virtual address of that frame
    sector_count: u64, // capacity from IDENTIFY words 100-103 (0 if IDENTIFY failed)
}
152
// ========== Controller ==========

/// Handle to one initialised AHCI HBA and its usable SATA ports.
pub struct AhciController {
    #[allow(dead_code)]
    abar_virt: u64, // MMIO virtual address of the HBA register space
    ports: Vec<AhciPort>, // ports that passed DET/SIG checks during init
}

// SAFETY: AhciController is only accessed behind SpinLock<Option<...>>
unsafe impl Send for AhciController {}
unsafe impl Sync for AhciController {}
164
// ========== Per-port IRQ completion state ==========
// These statics are accessed from the IRQ handler without locks.
// Indexed by port_num (0..32).

/// AHCI ABAR virtual address : written once during init, read by IRQ handler.
/// Zero means "controller not yet initialised" (the handler bails out).
static AHCI_ABAR_VIRT: AtomicU64 = AtomicU64::new(0);

/// PCI interrupt line used by this controller.
/// 0xFF = not yet known (init has not run or found no controller).
pub static AHCI_IRQ_LINE: AtomicU8 = AtomicU8::new(0xFF);

/// Per-port MMIO virtual addresses : written once during init.
/// Zero means the port is unused; the IRQ handler skips it.
static PORT_VIRT: [AtomicU64; 32] = {
    const INIT: AtomicU64 = AtomicU64::new(0);
    [INIT; 32]
};

/// Per-port slot-0 completion flags : set by IRQ handler, cleared by consumer.
static PORT_SLOT0_DONE: [AtomicBool; 32] = {
    const INIT: AtomicBool = AtomicBool::new(false);
    [INIT; 32]
};

/// Per-port slot-0 error flags : set by IRQ handler on task-file error,
/// cleared by `submit_cmd` before each command issue.
static PORT_SLOT0_ERROR: [AtomicBool; 32] = {
    const INIT: AtomicBool = AtomicBool::new(false);
    [INIT; 32]
};

/// Per-port wait queues : tasks block here while waiting for IRQ completion.
static PORT_WQ: [WaitQueue; 32] = {
    const INIT: WaitQueue = WaitQueue::new();
    [INIT; 32]
};
198
// ========== MMIO helpers ==========

/// Volatile 32-bit MMIO read at `base + off`.
///
/// # Safety
/// `base + off` must be a mapped, 4-byte-aligned MMIO virtual address.
#[inline]
unsafe fn rd32(base: u64, off: u64) -> u32 {
    ptr::read_volatile((base + off) as *const u32)
}
206
/// Volatile 32-bit MMIO write of `val` at `base + off`.
///
/// # Safety
/// `base + off` must be a mapped, 4-byte-aligned MMIO virtual address.
#[inline]
unsafe fn wr32(base: u64, off: u64, val: u32) {
    ptr::write_volatile((base + off) as *mut u32, val);
}
212
213// ========== Port start/stop ====================
214
215/// Performs the port stop operation.
216fn port_stop(pvirt: u64) {
217    // SAFETY: pvirt is a valid MMIO virtual address for this port's registers
218    unsafe {
219        let mut cmd = rd32(pvirt, PORT_CMD);
220        cmd &= !(CMD_ST | CMD_FRE);
221        wr32(pvirt, PORT_CMD, cmd);
222        // Spec mandates waiting ≤ 500 ms for FR and CR to clear
223        for _ in 0..500_000u32 {
224            if rd32(pvirt, PORT_CMD) & (CMD_FR | CMD_CR) == 0 {
225                return;
226            }
227            core::hint::spin_loop();
228        }
229        log::warn!("AHCI: port stop timed out (port registers @ {:#x})", pvirt);
230    }
231}
232
233/// Performs the port start operation.
234fn port_start(pvirt: u64) {
235    // SAFETY: pvirt is a valid MMIO virtual address
236    unsafe {
237        // Ensure Command List Running is clear before asserting ST
238        let mut timeout = 500_000u32;
239        while rd32(pvirt, PORT_CMD) & CMD_CR != 0 {
240            timeout = timeout.saturating_sub(1);
241            if timeout == 0 {
242                log::warn!("AHCI: port start wait for CR clear timed out");
243                break;
244            }
245            core::hint::spin_loop();
246        }
247        let mut cmd = rd32(pvirt, PORT_CMD);
248        cmd |= CMD_FRE | CMD_ST;
249        wr32(pvirt, PORT_CMD, cmd);
250    }
251}
252
/// Rebase port: assign our CLB/FB buffers then start the port.
///
/// `port_stop` is called first so the HBA is not DMA-ing into the old
/// buffers while CLB/FB are re-pointed to `phys`.
fn port_rebase(pvirt: u64, phys: u64) {
    port_stop(pvirt);
    // SAFETY: pvirt valid MMIO; phys is our allocated frame
    unsafe {
        let clb = phys + CLB_OFF;
        let fb = phys + FB_OFF;
        wr32(pvirt, PORT_CLB, (clb & 0xFFFF_FFFF) as u32);
        wr32(pvirt, PORT_CLBU, (clb >> 32) as u32);
        wr32(pvirt, PORT_FB, (fb & 0xFFFF_FFFF) as u32);
        wr32(pvirt, PORT_FBU, (fb >> 32) as u32);
        // Clear any stale interrupt/error status (both registers are W1C)
        wr32(pvirt, PORT_IS, 0xFFFF_FFFF);
        wr32(pvirt, PORT_SERR, 0xFFFF_FFFF);
    }
    port_start(pvirt);
}
270
271/// Enable interrupts for a port (DHRE = DMA completion, TFEE = errors).
272fn port_enable_irq(pvirt: u64) {
273    // SAFETY: pvirt is a valid MMIO port register address
274    unsafe {
275        wr32(pvirt, PORT_IE, PXIE_DHRE | PXIE_TFEE);
276    }
277}
278
// ========== Bounce-buffer management ==========

/// Physically-contiguous DMA staging buffer.  Data is bounced through here
/// because caller buffers are not guaranteed to be physically contiguous.
struct Bounce {
    frame: PhysFrame, // first frame of the contiguous allocation
    order: u8,        // buddy order (allocation is 2^order pages)
    phys: u64,        // physical base (programmed into the PRDT)
    virt: u64,        // HHDM virtual base (used for copies to/from caller buf)
}
287
288impl Bounce {
289    /// Performs the alloc operation.
290    fn alloc(bytes: usize) -> Result<Self, AhciError> {
291        let pages = (bytes + 4095) / 4096;
292        let order = pages.next_power_of_two().trailing_zeros() as u8;
293        let frame =
294            crate::sync::with_irqs_disabled(|token| memory::allocate_phys_contiguous(token, order))
295                .map_err(|_| AhciError::Alloc)?;
296        let phys = frame.start_address.as_u64();
297        Ok(Self {
298            frame,
299            order,
300            phys,
301            virt: phys_to_virt(phys),
302        })
303    }
304
305    /// Performs the free operation.
306    fn free(self) {
307        crate::sync::with_irqs_disabled(|token| {
308            memory::free_phys_contiguous(token, self.frame, self.order);
309        });
310    }
311}
312
// ========== Command submission ==========
//
// Two completion strategies:
//   1. Task context (current_task_id() is Some): IRQ + WaitQueue : the issuing
//      task is blocked by the scheduler until the IRQ fires and wakes it.
//   2. Boot context (no task yet): legacy busy-poll with timeout.

/// Build and issue a single ATA command in command slot 0 of `port`,
/// transferring `count` sectors to/from `buf` via a DMA bounce buffer.
///
/// - `lba`/`count`: 48-bit LBA and sector count encoded into the H2D FIS.
/// - `write`: direction; when true, `buf` is copied into the bounce buffer
///   before issue, otherwise the bounce buffer is copied out afterwards.
/// - `ata_cmd`: raw ATA command byte (IDENTIFY / READ DMA EXT / WRITE DMA EXT).
///
/// Returns `Busy` if the device is BSY/DRQ, `BufferTooSmall` if `buf` cannot
/// hold the transfer, `DeviceError` on a task-file error, `Timeout` if the
/// polling path gives up.  NOTE(review): the IRQ path has no timeout — a lost
/// interrupt would block the task indefinitely; confirm this is intended.
fn submit_cmd(
    port: &AhciPort,
    lba: u64,
    count: u16,
    buf: &mut [u8],
    write: bool,
    ata_cmd: u8,
) -> Result<(), AhciError> {
    let nbytes = (count as usize) * SECTOR_SIZE;
    if buf.len() < nbytes {
        return Err(AhciError::BufferTooSmall);
    }

    // SAFETY: MMIO read to check device readiness
    let tfd = unsafe { rd32(port.port_virt, PORT_TFD) };
    if tfd & (TFD_BSY | TFD_DRQ) != 0 {
        return Err(AhciError::Busy);
    }

    // Allocated after all early-return checks so nothing leaks before this
    // point; every later error path calls bounce.free() explicitly.
    let bounce = Bounce::alloc(nbytes)?;

    if write {
        // SAFETY: bounce.virt is a valid HHDM address ≥ nbytes; buf.len() ≥ nbytes
        unsafe {
            ptr::copy_nonoverlapping(buf.as_ptr(), bounce.virt as *mut u8, nbytes);
        }
    }

    let ctab_phys = port.mem_phys + CTAB_OFF;
    let cmdh_virt = port.mem_virt + CLB_OFF; // slot 0 = first 32 bytes of CLB
    let ctab_virt = port.mem_virt + CTAB_OFF;

    // SAFETY: cmdh_virt and ctab_virt point to our allocated frame (physically valid)
    unsafe {
        // --- Command header (slot 0, 32 bytes) ---
        let h = cmdh_virt as *mut u8;
        ptr::write_bytes(h, 0, 32);

        // CFL = 5 (H2D FIS = 20 B = 5 DWORDs); W bit set for writes
        let flags: u16 = 5u16 | (if write { 1 << 6 } else { 0 });
        ptr::write_unaligned(h.add(CMDH_FLAGS) as *mut u16, flags.to_le());
        ptr::write_unaligned(h.add(CMDH_PRDTL) as *mut u16, 1u16.to_le()); // 1 PRDT entry
        ptr::write_unaligned(
            h.add(CMDH_CTBA) as *mut u32,
            (ctab_phys & 0xFFFF_FFFF) as u32,
        );
        ptr::write_unaligned(h.add(CMDH_CTBAU) as *mut u32, (ctab_phys >> 32) as u32);

        // --- Command table ---
        let t = ctab_virt as *mut u8;
        ptr::write_bytes(t, 0, CTAB_PRDT + 16);

        // H2D Register FIS (20 bytes at CFIS offset)
        let f = t.add(CTAB_CFIS);
        *f.add(FIS_TYPE) = FIS_TYPE_H2D;
        *f.add(FIS_FLAGS) = FIS_C_BIT; // C=1: this FIS carries a command
        *f.add(FIS_CMD) = ata_cmd;
        *f.add(FIS_LBA0) = (lba & 0xFF) as u8;
        *f.add(FIS_LBA1) = ((lba >> 8) & 0xFF) as u8;
        *f.add(FIS_LBA2) = ((lba >> 16) & 0xFF) as u8;
        *f.add(FIS_DEVICE) = FIS_LBA_MODE; // LBA addressing, device 0
        *f.add(FIS_LBA3) = ((lba >> 24) & 0xFF) as u8;
        *f.add(FIS_LBA4) = ((lba >> 32) & 0xFF) as u8;
        *f.add(FIS_LBA5) = ((lba >> 40) & 0xFF) as u8;
        *f.add(FIS_CNT_LO) = (count & 0xFF) as u8;
        *f.add(FIS_CNT_HI) = (count >> 8) as u8;

        // PRDT entry 0 (16 bytes)
        let p = t.add(CTAB_PRDT);
        // DBA: physical address of DMA bounce buffer
        ptr::write_unaligned(p.add(0) as *mut u32, (bounce.phys & 0xFFFF_FFFF) as u32);
        ptr::write_unaligned(p.add(4) as *mut u32, (bounce.phys >> 32) as u32);
        ptr::write_unaligned(p.add(8) as *mut u32, 0u32); // reserved DWORD
        // DBC: byte_count - 1; bit 31 = interrupt on completion
        let dbc = ((nbytes as u32).saturating_sub(1)) | (1 << 31);
        ptr::write_unaligned(p.add(12) as *mut u32, dbc);
    }

    let idx = port.port_num as usize;

    // Clear any stale completion state before issuing the command
    PORT_SLOT0_DONE[idx].store(false, Ordering::Release);
    PORT_SLOT0_ERROR[idx].store(false, Ordering::Release);

    // Issue command in slot 0
    // SAFETY: MMIO write to PxCI
    unsafe { wr32(port.port_virt, PORT_CI, 1) };

    // Completion strategy ==========
    if crate::process::current_task_id().is_some() {
        // Task context: block until the IRQ handler signals DONE.
        // WaitQueue::wait_until() atomically checks the condition under the
        // waiters SpinLock (which disables IRQs via CLI), so the completion
        // interrupt cannot be lost between the check and the block.
        PORT_WQ[idx].wait_until(|| {
            if PORT_SLOT0_DONE[idx].load(Ordering::Acquire) {
                // Consume the flag atomically before returning
                PORT_SLOT0_DONE[idx].store(false, Ordering::Release);
                Some(())
            } else {
                None
            }
        });

        // Check whether the IRQ reported an error (the handler sets DONE on
        // both success and error; ERROR distinguishes the two outcomes)
        if PORT_SLOT0_ERROR[idx].load(Ordering::Acquire) {
            bounce.free();
            return Err(AhciError::DeviceError);
        }
    } else {
        // Boot context (no task): fall back to busy-poll with ≈5 s timeout.
        let mut tries = 5_000_000u32;
        loop {
            // SAFETY: MMIO reads
            let ci = unsafe { rd32(port.port_virt, PORT_CI) };
            let is = unsafe { rd32(port.port_virt, PORT_IS) };

            if is & PXIS_TFES != 0 {
                // SAFETY: MMIO writes to clear error status
                unsafe {
                    wr32(port.port_virt, PORT_IS, 0xFFFF_FFFF);
                    wr32(port.port_virt, PORT_SERR, 0xFFFF_FFFF);
                }
                bounce.free();
                return Err(AhciError::DeviceError);
            }

            if ci & 1 == 0 {
                break; // slot 0 completed
            }

            tries = tries.saturating_sub(1);
            if tries == 0 {
                bounce.free();
                return Err(AhciError::Timeout);
            }
            core::hint::spin_loop();
        }

        // SAFETY: MMIO write to clear port interrupt status
        unsafe { wr32(port.port_virt, PORT_IS, 0xFFFF_FFFF) };
    }

    if !write {
        // Copy the DMA result back to the caller's buffer
        // SAFETY: bounce.virt valid, nbytes ≤ allocated
        unsafe {
            ptr::copy_nonoverlapping(bounce.virt as *const u8, buf.as_mut_ptr(), nbytes);
        }
    }

    bounce.free();
    Ok(())
}
474
// ========== IRQ handler ==========

/// Called from the IDT AHCI IRQ handler.
///
/// Reads `HBA_IS` to find which ports raised an interrupt, reads and clears
/// `PxIS` per port, then signals the per-port `WaitQueue` so that any task
/// blocked in `submit_cmd` can resume.
///
/// Note: `PORT_SLOT0_DONE` is set for both success and error interrupts;
/// `PORT_SLOT0_ERROR` tells the woken waiter which outcome occurred.
///
/// # Safety of concurrent access
/// All per-port statics (`PORT_SLOT0_DONE`, `PORT_SLOT0_ERROR`, `PORT_WQ`) are
/// accessed via atomics or briefly-held SpinLocks (which disable IRQs).  The
/// IRQ handler itself is not re-entrant (x86 APIC level-triggered delivery
/// ensures this for the same vector).
pub fn handle_interrupt() {
    let abar = AHCI_ABAR_VIRT.load(Ordering::Relaxed);
    if abar == 0 {
        return; // controller not yet initialised
    }

    // SAFETY: abar is the MMIO-mapped AHCI base set during init
    let global_is = unsafe { rd32(abar, HBA_IS) };
    if global_is == 0 {
        return; // spurious
    }

    for port_num in 0..32u8 {
        if global_is & (1 << port_num) == 0 {
            continue; // this port did not interrupt
        }

        let pvirt = PORT_VIRT[port_num as usize].load(Ordering::Relaxed);
        if pvirt == 0 {
            continue; // port not in use
        }

        // SAFETY: pvirt is the valid MMIO address for this port, set during init
        let pxis = unsafe { rd32(pvirt, PORT_IS) };

        // Determine outcome and record in the error flag
        if pxis & PXIS_TFES != 0 {
            PORT_SLOT0_ERROR[port_num as usize].store(true, Ordering::Release);
            // SAFETY: MMIO writes to clear error state
            unsafe {
                wr32(pvirt, PORT_IS, pxis);
                wr32(pvirt, PORT_SERR, 0xFFFF_FFFF);
            }
        } else {
            PORT_SLOT0_ERROR[port_num as usize].store(false, Ordering::Release);
            // SAFETY: W1C : write back PxIS to clear all set bits
            unsafe { wr32(pvirt, PORT_IS, pxis) };
        }

        // Clear this port's bit in global IS (W1C); per spec the per-port
        // PxIS must be cleared before the corresponding HBA_IS bit.
        // SAFETY: MMIO write to HBA_IS
        unsafe { wr32(abar, HBA_IS, 1 << port_num) };

        // Signal command completion (set before wake so the waiter's
        // condition check observes it)
        PORT_SLOT0_DONE[port_num as usize].store(true, Ordering::Release);
        PORT_WQ[port_num as usize].wake_one();
    }
}
536
// ========== AhciController: probe / init / accessors ==========

impl AhciController {
    /// Probe and initialise an AHCI controller from the PCI bus.
    ///
    /// Sequence: find the PCI function, enable bus-mastering, map ABAR,
    /// reset the HBA, then for each implemented port with a live SATA
    /// device: rebase its command structures, enable its interrupts, and
    /// IDENTIFY it to learn the sector count.
    ///
    /// # Errors
    /// `NoController` if no matching PCI device, `BadAbar` for a missing/zero
    /// BAR5, `Alloc` on frame-allocation failure, `NoPort` if no usable SATA
    /// port was found.
    ///
    /// # Safety
    /// Must be called once during single-threaded kernel init (MMIO mapping).
    pub unsafe fn init() -> Result<Self, AhciError> {
        // AHCI: class=0x01, subclass=0x06 (SATA), prog_if=0x01 (AHCI 1.0)
        let pci_dev = pci::probe_first(ProbeCriteria {
            class_code: Some(pci::class::MASS_STORAGE),
            subclass: Some(pci::storage_subclass::SATA),
            prog_if: Some(pci::sata_progif::AHCI),
            ..ProbeCriteria::any()
        })
        .ok_or(AhciError::NoController)?;

        log::info!("AHCI: found controller at {:?}", pci_dev.address);

        // Enable bus-mastering and memory-space access (required for DMA)
        pci_dev.enable_bus_master();
        pci_dev.enable_memory_space();

        // Read PCI interrupt line before we need it later
        let irq_line = pci_dev.read_config_u8(pci::config::INTERRUPT_LINE);

        // BAR5 = ABAR (AHCI Base Memory Register)
        let abar_phys = pci_dev.read_bar_raw(5).ok_or(AhciError::BadAbar)?;
        if abar_phys == 0 {
            return Err(AhciError::BadAbar);
        }

        // Map the entire HBA register space (0x100 + 32 ports * 0x80 = 0x1100 bytes)
        // NOTE(review): range is identity-mapped but accessed via phys_to_virt
        // (HHDM) below — confirm both mappings cover this physical range.
        crate::memory::paging::ensure_identity_map_range(abar_phys, 0x1200);
        let abar_virt = phys_to_virt(abar_phys);

        // SAFETY: abar_virt is now a mapped MMIO virtual address
        // Enable AHCI mode (GHC.AE must be set before other HBA accesses)
        let ghc = rd32(abar_virt, HBA_GHC);
        if ghc & GHC_AE == 0 {
            wr32(abar_virt, HBA_GHC, ghc | GHC_AE);
        }

        // Perform HBA reset sequence (HR), then re-enable AHCI.
        // HR self-clears when the reset completes.
        let mut ghc_after = rd32(abar_virt, HBA_GHC) | GHC_AE;
        wr32(abar_virt, HBA_GHC, ghc_after | GHC_HR);
        let mut reset_timeout = 1_000_000u32;
        while rd32(abar_virt, HBA_GHC) & GHC_HR != 0 {
            reset_timeout = reset_timeout.saturating_sub(1);
            if reset_timeout == 0 {
                log::warn!("AHCI: HBA reset timed out, continuing with current state");
                break;
            }
            core::hint::spin_loop();
        }
        // Reset clears GHC.AE on some HBAs; re-assert it.
        ghc_after = rd32(abar_virt, HBA_GHC) | GHC_AE;
        wr32(abar_virt, HBA_GHC, ghc_after);
        // Clear global pending interrupts before per-port setup.
        wr32(abar_virt, HBA_IS, 0xFFFF_FFFF);

        log::debug!(
            "AHCI: ABAR phys={:#x} virt={:#x}  GHC={:#010x}",
            abar_phys,
            abar_virt,
            rd32(abar_virt, HBA_GHC)
        );

        let pi = rd32(abar_virt, HBA_PI); // bitmask of implemented ports
        log::debug!("AHCI: ports implemented mask = {:#010x}", pi);

        let mut ports: Vec<AhciPort> = Vec::new();

        for port_num in 0..32u8 {
            if pi & (1 << port_num) == 0 {
                continue; // port not implemented by this HBA
            }

            // Port register window: ABAR + 0x100 + n * 0x80
            let pvirt = abar_virt + 0x100 + (port_num as u64) * 0x80;

            // Check DET: only accept DET=3 (device present + communication)
            let ssts = rd32(pvirt, PORT_SSTS);
            let det = ssts & SSTS_DET_MASK;
            if det != SSTS_DET_COMM {
                log::debug!("AHCI: port {} DET={} : no device, skipping", port_num, det);
                continue;
            }

            // Only handle plain SATA (signature 0x00000101)
            let sig = rd32(pvirt, PORT_SIG);
            if sig != SIG_SATA {
                log::debug!(
                    "AHCI: port {} sig={:#010x} : not plain SATA, skipping",
                    port_num,
                    sig
                );
                continue;
            }

            // Allocate one 4 KB frame for CLB + FIS + CTAB
            let frame = crate::sync::with_irqs_disabled(|token| memory::allocate_frame(token))
                .map_err(|_| AhciError::Alloc)?;

            let mem_phys = frame.start_address.as_u64();
            let mem_virt = phys_to_virt(mem_phys);

            // Zero the frame so HBA sees clean structures
            // SAFETY: mem_virt is valid HHDM-mapped physical memory, 4096 bytes
            ptr::write_bytes(mem_virt as *mut u8, 0, 4096);

            port_rebase(pvirt, mem_phys);

            // Enable per-port interrupts (DHRE + TFEE)
            port_enable_irq(pvirt);

            // Register this port's MMIO address in the per-port static table
            // so the IRQ handler can access it without holding the controller lock.
            PORT_VIRT[port_num as usize].store(pvirt, Ordering::Relaxed);

            // Identify device to read sector count
            let mut port = AhciPort {
                port_num,
                port_virt: pvirt,
                mem_phys,
                mem_virt,
                sector_count: 0,
            };

            let mut id_buf = [0u8; SECTOR_SIZE];
            match submit_cmd(&port, 0, 1, &mut id_buf, false, ATA_IDENTIFY) {
                Ok(()) => {
                    // Words 100-103 (bytes 200-207): 48-bit LBA native max address
                    let w0 = u16::from_le_bytes([id_buf[200], id_buf[201]]) as u64;
                    let w1 = u16::from_le_bytes([id_buf[202], id_buf[203]]) as u64;
                    let w2 = u16::from_le_bytes([id_buf[204], id_buf[205]]) as u64;
                    let w3 = u16::from_le_bytes([id_buf[206], id_buf[207]]) as u64;
                    port.sector_count = w0 | (w1 << 16) | (w2 << 32) | (w3 << 48);
                    log::info!(
                        "AHCI: port {} SATA : {} sectors ({} MiB)",
                        port_num,
                        port.sector_count,
                        (port.sector_count * SECTOR_SIZE as u64) / (1024 * 1024)
                    );
                }
                Err(e) => {
                    // Port is kept (sector_count stays 0); reads/writes will
                    // be rejected by the sector-range check.
                    log::warn!("AHCI: port {} IDENTIFY failed: {}", port_num, e);
                }
            }

            ports.push(port);
        }

        if ports.is_empty() {
            return Err(AhciError::NoPort);
        }

        // Store the ABAR virtual address and IRQ line in statics so the
        // interrupt handler can reach them without going through the controller lock.
        AHCI_ABAR_VIRT.store(abar_virt, Ordering::Relaxed);
        AHCI_IRQ_LINE.store(irq_line, Ordering::Relaxed);

        // Enable global HBA interrupts (GHC.IE)
        // SAFETY: MMIO write : all port interrupts already enabled above
        let ghc = rd32(abar_virt, HBA_GHC);
        wr32(abar_virt, HBA_GHC, ghc | GHC_IE);

        log::info!("AHCI: global interrupts enabled (IRQ line {})", irq_line);

        Ok(AhciController { abar_virt, ports })
    }

    /// Return sector count of the first port (0 if the driver has no port).
    pub fn sector_count(&self) -> u64 {
        self.ports.first().map(|p| p.sector_count).unwrap_or(0)
    }

    /// First usable port, if any; all BlockDevice I/O targets this port.
    fn first_port(&self) -> Option<&AhciPort> {
        self.ports.first()
    }
}
717
718impl BlockDevice for AhciController {
719    /// Reads sector.
720    fn read_sector(&self, sector: u64, buf: &mut [u8]) -> Result<(), BlockError> {
721        let port = self.first_port().ok_or(BlockError::NotReady)?;
722        if sector >= port.sector_count {
723            return Err(BlockError::InvalidSector);
724        }
725        if buf.len() < SECTOR_SIZE {
726            return Err(BlockError::BufferTooSmall);
727        }
728        submit_cmd(port, sector, 1, buf, false, ATA_READ_DMA_EXT).map_err(|_| BlockError::IoError)
729    }
730
731    /// Writes sector.
732    fn write_sector(&self, sector: u64, buf: &[u8]) -> Result<(), BlockError> {
733        let port = self.first_port().ok_or(BlockError::NotReady)?;
734        if sector >= port.sector_count {
735            return Err(BlockError::InvalidSector);
736        }
737        if buf.len() < SECTOR_SIZE {
738            return Err(BlockError::BufferTooSmall);
739        }
740        // submit_cmd needs &mut [u8]; copy to a mutable staging buffer
741        let mut tmp = [0u8; SECTOR_SIZE];
742        tmp.copy_from_slice(&buf[..SECTOR_SIZE]);
743        submit_cmd(port, sector, 1, &mut tmp, true, ATA_WRITE_DMA_EXT)
744            .map_err(|_| BlockError::IoError)
745    }
746
747    /// Performs the sector count operation.
748    fn sector_count(&self) -> u64 {
749        self.sector_count()
750    }
751}
752
// ========== Global singleton + public API ==========

/// Global AHCI controller: `None` until `init()` succeeds.  Boxed so
/// `get_device()` can hand out a stable reference to the heap allocation.
static AHCI: SpinLock<Option<Box<AhciController>>> = SpinLock::new(None);
756
757/// Scan the PCI bus for an AHCI controller and initialise it.
758///
759/// Called once during kernel boot from `hardware::init()`.
760pub fn init() {
761    log::info!("AHCI: scanning PCI bus...");
762
763    match unsafe { AhciController::init() } {
764        Ok(ctrl) => {
765            *AHCI.lock() = Some(Box::new(ctrl));
766            log::info!("AHCI: controller ready");
767
768            // Register IRQ handler in the IDT now that the controller is live.
769            let irq = AHCI_IRQ_LINE.load(Ordering::Relaxed);
770            crate::arch::x86_64::idt::register_ahci_irq(irq);
771        }
772        Err(AhciError::NoController) => {
773            log::info!("AHCI: no controller found (not a SATA system?)");
774        }
775        Err(e) => {
776            log::error!("AHCI: init failed: {}", e);
777        }
778    }
779}
780
/// Return a reference to the first usable AHCI controller, if any.
///
/// NOTE(review): the `'static` lifetime is manufactured by taking a raw
/// pointer to the boxed controller while the lock guard is held.  This is
/// sound only if the global `Option<Box<...>>` is set exactly once and never
/// cleared or replaced afterwards — confirm no code path ever resets `AHCI`.
pub fn get_device() -> Option<&'static AhciController> {
    // SAFETY: the global Option is only ever set during init and never cleared
    unsafe {
        let lock = AHCI.lock();
        lock.as_ref().map(|b| {
            // Extend the borrow past the lock guard via a raw pointer; the
            // Box's heap allocation outlives the guard (see NOTE above).
            let ptr = b.as_ref() as *const AhciController;
            &*ptr
        })
    }
}