// strat9_kernel/hardware/storage/ahci.rs

1//! AHCI (Advanced Host Controller Interface) driver — AHCI spec 1.3.1
2//!
3//! PCI: class=0x01 (Mass Storage), subclass=0x06 (SATA), prog_if=0x01
4//! MMIO base: BAR5 (ABAR)
5//!
6//! Per-port memory layout (packed into one 4 KB page):
7//!   [0x000..0x3FF]  Command List   (1024 B, 32 × 32-byte headers)
8//!   [0x400..0x4FF]  FIS receive    (256 B)
9//!   [0x500..0x5FF]  Command table  (128 B header + 1 × 16-byte PRDT)
10//!
11//! ## IRQ-driven completion (DRV-02 v2)
12//!
13//! When a task context exists (`current_task_id()` is `Some`), commands are
14//! completed via interrupt + `WaitQueue::wait_until()` so the issuing task
15//! blocks without busy-spinning.  During early boot (no task yet) the legacy
16//! polling path is used as a fallback.
17//!
18//! Per-port statics (indexed by `port_num 0..32`) are used so the IRQ handler
19//! can signal completion without acquiring any slow lock:
20//!   - `PORT_VIRT[n]`       — MMIO virtual address of port n registers
21//!   - `PORT_SLOT0_DONE[n]` — set by IRQ handler when slot-0 completes
22//!   - `PORT_SLOT0_ERROR[n]`— set by IRQ handler when a task-file error fires
23//!   - `PORT_WQ[n]`         — WaitQueue; issuing task blocks here
24
25use crate::{
26    hardware::pci_client::{self as pci, ProbeCriteria},
27    memory::{self, phys_to_virt, PhysFrame},
28    sync::{SpinLock, WaitQueue},
29};
30use alloc::{boxed::Box, vec::Vec};
31use core::{
32    ptr,
33    sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
34};
35
36pub use super::virtio_block::{BlockDevice, BlockError, SECTOR_SIZE};
37
// ─── HBA generic registers (at ABAR) ─────────────────────────────────────────
const HBA_GHC: u64 = 0x04; // Global HBA Control
const HBA_IS: u64 = 0x08; // Interrupt Status — one bit per port, write-1-to-clear
const HBA_PI: u64 = 0x0C; // Ports Implemented bitmask

const GHC_AE: u32 = 1 << 31; // AHCI Enable
const GHC_IE: u32 = 1 << 1; // Global Interrupt Enable
const GHC_HR: u32 = 1 << 0; // HBA Reset

// ─── Port register offsets (relative to port base = ABAR + 0x100 + n*0x80) ──
const PORT_CLB: u64 = 0x00; // Command List Base Address (low 32 bits)
const PORT_CLBU: u64 = 0x04; // Command List Base Address (upper 32 bits)
const PORT_FB: u64 = 0x08; // FIS Base Address (low 32 bits)
const PORT_FBU: u64 = 0x0C; // FIS Base Address (upper 32 bits)
const PORT_IS: u64 = 0x10; // PxIS — port interrupt status (W1C)
const PORT_IE: u64 = 0x14; // PxIE — port interrupt enable
const PORT_CMD: u64 = 0x18; // PxCMD — command and status
const PORT_TFD: u64 = 0x20; // PxTFD — task file data (BSY/DRQ/ERR shadow)
const PORT_SIG: u64 = 0x24; // PxSIG — attached-device signature
const PORT_SSTS: u64 = 0x28; // PxSSTS — SATA status (DET in low nibble)
const PORT_SERR: u64 = 0x30; // PxSERR — SATA error (W1C)
const PORT_CI: u64 = 0x38; // PxCI — command issue, one bit per slot

const CMD_ST: u32 = 1 << 0; // Start
const CMD_FRE: u32 = 1 << 4; // FIS Receive Enable
const CMD_FR: u32 = 1 << 14; // FIS Receive Running
const CMD_CR: u32 = 1 << 15; // Command List Running

const TFD_BSY: u32 = 1 << 7; // device busy
const TFD_DRQ: u32 = 1 << 3; // device requesting data transfer

const SSTS_DET_COMM: u32 = 3; // device present, phy communication established
const SSTS_DET_MASK: u32 = 0xF;

const SIG_SATA: u32 = 0x0000_0101; // plain SATA drive (not ATAPI/PM/SEMB)

// PxIE bits
const PXIE_DHRE: u32 = 1 << 0; // D2H Register FIS Received Enable (normal DMA completion)
const PXIE_TFEE: u32 = 1 << 30; // Task File Error Enable

// ─── Per-port memory layout offsets ──────────────────────────────────────────
const CLB_OFF: u64 = 0x000; // Command List (1024 B)
const FB_OFF: u64 = 0x400; // FIS buffer   (256 B)
const CTAB_OFF: u64 = 0x500; // Command Table (128 B header + 16 B PRDT)

// Command header field byte offsets within a 32-byte slot
const CMDH_FLAGS: usize = 0; // u16: cfl[4:0] | a | w | p | r | b | c
const CMDH_PRDTL: usize = 2; // u16
const CMDH_CTBA: usize = 8; // u32
const CMDH_CTBAU: usize = 12; // u32

// Command table FIS and PRDT offsets
const CTAB_CFIS: usize = 0x00; // H2D FIS (64 B allocated)
const CTAB_PRDT: usize = 0x80; // PRDT entries

// H2D FIS field offsets (FIS type 0x27, Register Host-to-Device)
const FIS_TYPE: usize = 0;
const FIS_FLAGS: usize = 1; // PM port [3:0] | C [7]
const FIS_CMD: usize = 2;
const FIS_LBA0: usize = 4;
const FIS_LBA1: usize = 5;
const FIS_LBA2: usize = 6;
const FIS_DEVICE: usize = 7;
const FIS_LBA3: usize = 8;
const FIS_LBA4: usize = 9;
const FIS_LBA5: usize = 10;
const FIS_CNT_LO: usize = 12;
const FIS_CNT_HI: usize = 13;

const FIS_TYPE_H2D: u8 = 0x27;
const FIS_C_BIT: u8 = 0x80; // command (not control)
const FIS_LBA_MODE: u8 = 1 << 6; // device register: select LBA addressing

// ATA commands (48-bit LBA)
const ATA_IDENTIFY: u8 = 0xEC;
const ATA_READ_DMA_EXT: u8 = 0x25;
const ATA_WRITE_DMA_EXT: u8 = 0x35;

// PxIS bit 30 = Task File Error Status
const PXIS_TFES: u32 = 1 << 30;
118
119// ─── Error type ──────────────────────────────────────────────────────────────
120
/// Errors produced by the AHCI driver (probing, init, and command submission).
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum AhciError {
    #[error("no AHCI controller on PCI bus")]
    NoController,
    #[error("invalid BAR5 (ABAR)")]
    BadAbar,
    #[error("physical memory allocation failed")]
    Alloc,
    #[error("port BSY/DRQ set")]
    Busy,
    #[error("command timed out")]
    Timeout,
    #[error("device reported task-file error")]
    DeviceError,
    #[error("invalid sector number")]
    InvalidSector,
    #[error("buffer too small (need ≥ SECTOR_SIZE bytes)")]
    BufferTooSmall,
    #[error("no usable SATA port found")]
    NoPort,
}
142
143// ─── Internal port handle ─────────────────────────────────────────────────────
144
/// State for one initialised SATA port: its register window plus the DMA
/// structures (command list / FIS buffer / command table) backing it.
struct AhciPort {
    port_num: u8,
    port_virt: u64, // virtual address of port registers
    mem_phys: u64,  // physical base of the per-port CLB/FB/CTAB frame
    mem_virt: u64,  // HHDM virtual address of that frame
    sector_count: u64, // total addressable sectors; left 0 if IDENTIFY failed
}
152
153// ─── Controller ──────────────────────────────────────────────────────────────
154
/// Driver handle for one AHCI HBA: the mapped ABAR plus all usable ports.
pub struct AhciController {
    #[allow(dead_code)]
    abar_virt: u64, // mapped HBA register base; kept for debugging/future use
    ports: Vec<AhciPort>,
}

// SAFETY: AhciController is only accessed behind SpinLock<Option<...>>
unsafe impl Send for AhciController {}
unsafe impl Sync for AhciController {}
164
165// ─── Per-port IRQ completion state ───────────────────────────────────────────
166// These statics are accessed from the IRQ handler without locks.
167// Indexed by port_num (0..32).
168
/// AHCI ABAR virtual address — written once during init, read by IRQ handler.
static AHCI_ABAR_VIRT: AtomicU64 = AtomicU64::new(0);

/// PCI interrupt line used by this controller.
pub static AHCI_IRQ_LINE: AtomicU8 = AtomicU8::new(0xFF);

/// Per-port MMIO virtual addresses — written once during init.
static PORT_VIRT: [AtomicU64; 32] = {
    // A `const` item makes `[INIT; 32]` legal even though atomics are not Copy.
    const INIT: AtomicU64 = AtomicU64::new(0);
    [INIT; 32]
};

/// Per-port slot-0 completion flags — set by IRQ handler, cleared by consumer.
static PORT_SLOT0_DONE: [AtomicBool; 32] = {
    const INIT: AtomicBool = AtomicBool::new(false);
    [INIT; 32]
};

/// Per-port slot-0 error flags — set by IRQ handler on task-file error.
static PORT_SLOT0_ERROR: [AtomicBool; 32] = {
    const INIT: AtomicBool = AtomicBool::new(false);
    [INIT; 32]
};

/// Per-port wait queues — tasks block here while waiting for IRQ completion.
static PORT_WQ: [WaitQueue; 32] = {
    const INIT: WaitQueue = WaitQueue::new();
    [INIT; 32]
};
198
199// ─── MMIO helpers ─────────────────────────────────────────────────────────────
200
/// Volatile 32-bit read of the register at `base + off`.
///
/// # Safety
/// `base + off` must be a mapped, readable, 4-byte-aligned address (MMIO
/// register or ordinary memory) for the duration of the call.
#[inline]
unsafe fn rd32(base: u64, off: u64) -> u32 {
    let reg = (base + off) as *const u32;
    ptr::read_volatile(reg)
}
206
/// Volatile 32-bit write of `val` to the register at `base + off`.
///
/// # Safety
/// `base + off` must be a mapped, writable, 4-byte-aligned address (MMIO
/// register or ordinary memory) for the duration of the call.
#[inline]
unsafe fn wr32(base: u64, off: u64, val: u32) {
    let reg = (base + off) as *mut u32;
    ptr::write_volatile(reg, val);
}
212
213// ─── Port start/stop ──────────────────────────────────────────────────────────
214
215/// Performs the port stop operation.
216fn port_stop(pvirt: u64) {
217    // SAFETY: pvirt is a valid MMIO virtual address for this port's registers
218    unsafe {
219        let mut cmd = rd32(pvirt, PORT_CMD);
220        cmd &= !(CMD_ST | CMD_FRE);
221        wr32(pvirt, PORT_CMD, cmd);
222        // Spec mandates waiting ≤ 500 ms for FR and CR to clear
223        for _ in 0..500_000u32 {
224            if rd32(pvirt, PORT_CMD) & (CMD_FR | CMD_CR) == 0 {
225                return;
226            }
227            core::hint::spin_loop();
228        }
229        log::warn!("AHCI: port stop timed out (port registers @ {:#x})", pvirt);
230    }
231}
232
233/// Performs the port start operation.
234fn port_start(pvirt: u64) {
235    // SAFETY: pvirt is a valid MMIO virtual address
236    unsafe {
237        // Ensure Command List Running is clear before asserting ST
238        let mut timeout = 500_000u32;
239        while rd32(pvirt, PORT_CMD) & CMD_CR != 0 {
240            timeout = timeout.saturating_sub(1);
241            if timeout == 0 {
242                log::warn!("AHCI: port start wait for CR clear timed out");
243                break;
244            }
245            core::hint::spin_loop();
246        }
247        let mut cmd = rd32(pvirt, PORT_CMD);
248        cmd |= CMD_FRE | CMD_ST;
249        wr32(pvirt, PORT_CMD, cmd);
250    }
251}
252
/// Rebase port: point the HBA at our CLB/FB buffers, clear stale status,
/// then start the port.
///
/// The port is stopped first because PxCLB/PxFB may only be changed while
/// the port's command list and FIS receive engines are idle.
fn port_rebase(pvirt: u64, phys: u64) {
    port_stop(pvirt);
    // SAFETY: pvirt valid MMIO; phys is our allocated frame
    unsafe {
        let clb = phys + CLB_OFF;
        let fb = phys + FB_OFF;
        wr32(pvirt, PORT_CLB, (clb & 0xFFFF_FFFF) as u32);
        wr32(pvirt, PORT_CLBU, (clb >> 32) as u32);
        wr32(pvirt, PORT_FB, (fb & 0xFFFF_FFFF) as u32);
        wr32(pvirt, PORT_FBU, (fb >> 32) as u32);
        // Clear any stale interrupt/error status (both registers are W1C)
        wr32(pvirt, PORT_IS, 0xFFFF_FFFF);
        wr32(pvirt, PORT_SERR, 0xFFFF_FFFF);
    }
    port_start(pvirt);
}
270
/// Enable interrupts for a port (DHRE = DMA completion, TFEE = errors).
///
/// Interrupts are only actually delivered once GHC.IE is also set, which
/// `AhciController::init` does after all ports are configured.
fn port_enable_irq(pvirt: u64) {
    // SAFETY: pvirt is a valid MMIO port register address
    unsafe {
        wr32(pvirt, PORT_IE, PXIE_DHRE | PXIE_TFEE);
    }
}
278
279// ─── Bounce-buffer management ─────────────────────────────────────────────────
280
/// A physically-contiguous DMA bounce buffer from the frame allocator.
/// The HBA DMAs to/from `phys` (programmed into the PRDT); the driver
/// stages caller data through the HHDM alias at `virt`.
struct Bounce {
    frame: PhysFrame, // first frame of the allocation
    order: u8,        // allocation order: 2^order frames
    phys: u64,        // physical base, handed to the HBA
    virt: u64,        // HHDM virtual alias for CPU-side copies
}

impl Bounce {
    /// Allocate a buffer of at least `bytes` bytes, rounded up to a
    /// power-of-two number of 4 KiB frames.
    fn alloc(bytes: usize) -> Result<Self, AhciError> {
        let pages = (bytes + 4095) / 4096;
        // Frame allocator works in power-of-two orders, so round up.
        let order = pages.next_power_of_two().trailing_zeros() as u8;
        let frame = crate::sync::with_irqs_disabled(|token| {
            memory::allocate_frames(token, order)
        })
        .map_err(|_| AhciError::Alloc)?;
        let phys = frame.start_address.as_u64();
        Ok(Self {
            frame,
            order,
            phys,
            virt: phys_to_virt(phys),
        })
    }

    /// Return the frames to the allocator.  Must only be called once the
    /// HBA can no longer DMA into this buffer.
    fn free(self) {
        crate::sync::with_irqs_disabled(|token| {
            memory::free_frames(token, self.frame, self.order);
        });
    }
}
313
314// ─── Command submission ───────────────────────────────────────────────────────
315//
316// Two completion strategies:
317//   1. Task context (current_task_id() is Some): IRQ + WaitQueue — the issuing
318//      task is blocked by the scheduler until the IRQ fires and wakes it.
319//   2. Boot context (no task yet): legacy busy-poll with timeout.
320
/// Build and issue a single command in slot 0 of `port`, then wait for it.
///
/// Transfers `count` sectors between `buf` and the device starting at `lba`,
/// using a single PRDT entry pointing at a freshly-allocated bounce buffer.
/// For writes, `buf` is copied into the bounce buffer before issue; for
/// reads, the bounce buffer is copied back into `buf` after completion.
///
/// Errors: `BufferTooSmall` if `buf` cannot hold `count` sectors, `Busy` if
/// the device has BSY/DRQ set, `DeviceError` on a task-file error, and
/// `Timeout` if the boot-time polling fallback expires.
fn submit_cmd(
    port: &AhciPort,
    lba: u64,
    count: u16,
    buf: &mut [u8],
    write: bool,
    ata_cmd: u8,
) -> Result<(), AhciError> {
    let nbytes = (count as usize) * SECTOR_SIZE;
    if buf.len() < nbytes {
        return Err(AhciError::BufferTooSmall);
    }

    // SAFETY: MMIO read to check device readiness
    let tfd = unsafe { rd32(port.port_virt, PORT_TFD) };
    if tfd & (TFD_BSY | TFD_DRQ) != 0 {
        return Err(AhciError::Busy);
    }

    let bounce = Bounce::alloc(nbytes)?;

    if write {
        // SAFETY: bounce.virt is a valid HHDM address ≥ nbytes; buf.len() ≥ nbytes
        unsafe {
            ptr::copy_nonoverlapping(buf.as_ptr(), bounce.virt as *mut u8, nbytes);
        }
    }

    let ctab_phys = port.mem_phys + CTAB_OFF;
    let cmdh_virt = port.mem_virt + CLB_OFF; // slot 0 = first 32 bytes of CLB
    let ctab_virt = port.mem_virt + CTAB_OFF;

    // SAFETY: cmdh_virt and ctab_virt point to our allocated frame (physically valid)
    unsafe {
        // --- Command header (slot 0, 32 bytes) ---
        let h = cmdh_virt as *mut u8;
        ptr::write_bytes(h, 0, 32);

        // CFL = 5 (H2D FIS = 20 B = 5 DWORDs); W bit set for writes
        let flags: u16 = 5u16 | (if write { 1 << 6 } else { 0 });
        ptr::write_unaligned(h.add(CMDH_FLAGS) as *mut u16, flags.to_le());
        ptr::write_unaligned(h.add(CMDH_PRDTL) as *mut u16, 1u16.to_le()); // 1 PRDT entry
        ptr::write_unaligned(
            h.add(CMDH_CTBA) as *mut u32,
            (ctab_phys & 0xFFFF_FFFF) as u32,
        );
        ptr::write_unaligned(h.add(CMDH_CTBAU) as *mut u32, (ctab_phys >> 32) as u32);

        // --- Command table ---
        let t = ctab_virt as *mut u8;
        ptr::write_bytes(t, 0, CTAB_PRDT + 16);

        // H2D Register FIS (20 bytes at CFIS offset)
        let f = t.add(CTAB_CFIS);
        *f.add(FIS_TYPE) = FIS_TYPE_H2D;
        *f.add(FIS_FLAGS) = FIS_C_BIT; // C=1: this FIS carries a command
        *f.add(FIS_CMD) = ata_cmd;
        *f.add(FIS_LBA0) = (lba & 0xFF) as u8;
        *f.add(FIS_LBA1) = ((lba >> 8) & 0xFF) as u8;
        *f.add(FIS_LBA2) = ((lba >> 16) & 0xFF) as u8;
        *f.add(FIS_DEVICE) = FIS_LBA_MODE; // LBA addressing, device 0
        *f.add(FIS_LBA3) = ((lba >> 24) & 0xFF) as u8;
        *f.add(FIS_LBA4) = ((lba >> 32) & 0xFF) as u8;
        *f.add(FIS_LBA5) = ((lba >> 40) & 0xFF) as u8;
        *f.add(FIS_CNT_LO) = (count & 0xFF) as u8;
        *f.add(FIS_CNT_HI) = (count >> 8) as u8;

        // PRDT entry 0 (16 bytes)
        let p = t.add(CTAB_PRDT);
        // DBA: physical address of DMA bounce buffer
        ptr::write_unaligned(p.add(0) as *mut u32, (bounce.phys & 0xFFFF_FFFF) as u32);
        ptr::write_unaligned(p.add(4) as *mut u32, (bounce.phys >> 32) as u32);
        ptr::write_unaligned(p.add(8) as *mut u32, 0u32); // reserved dword
        // DBC: byte_count - 1; bit 31 = interrupt on completion
        let dbc = ((nbytes as u32).saturating_sub(1)) | (1 << 31);
        ptr::write_unaligned(p.add(12) as *mut u32, dbc);
    }

    let idx = port.port_num as usize;

    // Clear any stale completion state before issuing the command
    PORT_SLOT0_DONE[idx].store(false, Ordering::Release);
    PORT_SLOT0_ERROR[idx].store(false, Ordering::Release);

    // Issue command in slot 0
    // SAFETY: MMIO write to PxCI
    unsafe { wr32(port.port_virt, PORT_CI, 1) };

    // ── Completion strategy ────────────────────────────────────────────────────
    if crate::process::current_task_id().is_some() {
        // Task context: block until the IRQ handler signals DONE.
        // WaitQueue::wait_until() atomically checks the condition under the
        // waiters SpinLock (which disables IRQs via CLI), so the completion
        // interrupt cannot be lost between the check and the block.
        // NOTE(review): this path has no timeout — a lost interrupt would
        // block the issuing task indefinitely.
        PORT_WQ[idx].wait_until(|| {
            if PORT_SLOT0_DONE[idx].load(Ordering::Acquire) {
                // Consume the flag atomically before returning
                PORT_SLOT0_DONE[idx].store(false, Ordering::Release);
                Some(())
            } else {
                None
            }
        });

        // Check whether the IRQ reported an error
        if PORT_SLOT0_ERROR[idx].load(Ordering::Acquire) {
            bounce.free();
            return Err(AhciError::DeviceError);
        }
    } else {
        // Boot context (no task): fall back to busy-poll with ≈5 s timeout.
        let mut tries = 5_000_000u32;
        loop {
            // SAFETY: MMIO reads
            let ci = unsafe { rd32(port.port_virt, PORT_CI) };
            let is = unsafe { rd32(port.port_virt, PORT_IS) };

            if is & PXIS_TFES != 0 {
                // SAFETY: MMIO writes to clear error status
                unsafe {
                    wr32(port.port_virt, PORT_IS, 0xFFFF_FFFF);
                    wr32(port.port_virt, PORT_SERR, 0xFFFF_FFFF);
                }
                bounce.free();
                return Err(AhciError::DeviceError);
            }

            if ci & 1 == 0 {
                break; // slot 0 completed
            }

            tries = tries.saturating_sub(1);
            if tries == 0 {
                bounce.free();
                return Err(AhciError::Timeout);
            }
            core::hint::spin_loop();
        }

        // SAFETY: MMIO write to clear port interrupt status
        unsafe { wr32(port.port_virt, PORT_IS, 0xFFFF_FFFF) };
    }

    if !write {
        // SAFETY: bounce.virt valid, nbytes ≤ allocated
        unsafe {
            ptr::copy_nonoverlapping(bounce.virt as *const u8, buf.as_mut_ptr(), nbytes);
        }
    }

    bounce.free();
    Ok(())
}
475
476// ─── IRQ handler ─────────────────────────────────────────────────────────────
477
/// Called from the IDT AHCI IRQ handler.
///
/// Reads `HBA_IS` to find which ports raised an interrupt, reads and clears
/// `PxIS` per port, then signals the per-port `WaitQueue` so that any task
/// blocked in `submit_cmd` can resume.
///
/// # Safety of concurrent access
/// All per-port statics (`PORT_SLOT0_DONE`, `PORT_SLOT0_ERROR`, `PORT_WQ`) are
/// accessed via atomics or briefly-held SpinLocks (which disable IRQs).  The
/// IRQ handler itself is not re-entrant (x86 APIC level-triggered delivery
/// ensures this for the same vector).
pub fn handle_interrupt() {
    let abar = AHCI_ABAR_VIRT.load(Ordering::Relaxed);
    if abar == 0 {
        return; // controller not yet initialised
    }

    // SAFETY: abar is the MMIO-mapped AHCI base set during init
    let global_is = unsafe { rd32(abar, HBA_IS) };
    if global_is == 0 {
        return; // spurious
    }

    for port_num in 0..32u8 {
        if global_is & (1 << port_num) == 0 {
            continue;
        }

        let pvirt = PORT_VIRT[port_num as usize].load(Ordering::Relaxed);
        if pvirt == 0 {
            continue; // port not in use
        }

        // SAFETY: pvirt is the valid MMIO address for this port, set during init
        let pxis = unsafe { rd32(pvirt, PORT_IS) };

        // Determine outcome and record in the error flag
        if pxis & PXIS_TFES != 0 {
            PORT_SLOT0_ERROR[port_num as usize].store(true, Ordering::Release);
            // SAFETY: MMIO writes to clear error state
            unsafe {
                wr32(pvirt, PORT_IS, pxis);
                wr32(pvirt, PORT_SERR, 0xFFFF_FFFF);
            }
        } else {
            PORT_SLOT0_ERROR[port_num as usize].store(false, Ordering::Release);
            // SAFETY: W1C — write back PxIS to clear all set bits
            unsafe { wr32(pvirt, PORT_IS, pxis) };
        }

        // Clear this port's bit in global IS (W1C) — after PxIS, so the HBA
        // does not immediately re-assert the summary bit.
        // SAFETY: MMIO write to HBA_IS
        unsafe { wr32(abar, HBA_IS, 1 << port_num) };

        // Signal command completion.  DONE is set even on error so the waiter
        // wakes; it then inspects PORT_SLOT0_ERROR to classify the result.
        PORT_SLOT0_DONE[port_num as usize].store(true, Ordering::Release);
        PORT_WQ[port_num as usize].wake_one();
    }
}
537
538// ─── BlockDevice impl for AhciController ─────────────────────────────────────
539
impl AhciController {
    /// Probe and initialise an AHCI controller from the PCI bus.
    ///
    /// Sequence: locate the PCI function, enable bus-master + memory space,
    /// map ABAR, reset the HBA, then rebase, IRQ-enable and IDENTIFY every
    /// implemented port that has an attached plain-SATA device.
    ///
    /// # Safety
    /// Must be called once during single-threaded kernel init (MMIO mapping).
    pub unsafe fn init() -> Result<Self, AhciError> {
        // AHCI: class=0x01, subclass=0x06 (SATA), prog_if=0x01 (AHCI 1.0)
        let pci_dev = pci::probe_first(ProbeCriteria {
            class_code: Some(pci::class::MASS_STORAGE),
            subclass: Some(pci::storage_subclass::SATA),
            prog_if: Some(pci::sata_progif::AHCI),
            ..ProbeCriteria::any()
        })
        .ok_or(AhciError::NoController)?;

        log::info!("AHCI: found controller at {:?}", pci_dev.address);

        // Enable bus-mastering and memory-space access (required for DMA)
        pci_dev.enable_bus_master();
        pci_dev.enable_memory_space();

        // Read PCI interrupt line before we need it later
        let irq_line = pci_dev.read_config_u8(pci::config::INTERRUPT_LINE);

        // BAR5 = ABAR (AHCI Base Memory Register)
        let abar_phys = pci_dev.read_bar_raw(5).ok_or(AhciError::BadAbar)?;
        if abar_phys == 0 {
            return Err(AhciError::BadAbar);
        }

        // Map the entire HBA register space (0x100 + 32 ports * 0x80 = 0x1100 bytes)
        crate::memory::paging::ensure_identity_map_range(abar_phys, 0x1200);
        let abar_virt = phys_to_virt(abar_phys);

        // SAFETY: abar_virt is now a mapped MMIO virtual address
        // Enable AHCI mode
        let ghc = rd32(abar_virt, HBA_GHC);
        if ghc & GHC_AE == 0 {
            wr32(abar_virt, HBA_GHC, ghc | GHC_AE);
        }

        // Perform HBA reset sequence (HR), then re-enable AHCI.
        let mut ghc_after = rd32(abar_virt, HBA_GHC) | GHC_AE;
        wr32(abar_virt, HBA_GHC, ghc_after | GHC_HR);
        let mut reset_timeout = 1_000_000u32;
        // HR self-clears when the HBA has finished resetting.
        while rd32(abar_virt, HBA_GHC) & GHC_HR != 0 {
            reset_timeout = reset_timeout.saturating_sub(1);
            if reset_timeout == 0 {
                log::warn!("AHCI: HBA reset timed out, continuing with current state");
                break;
            }
            core::hint::spin_loop();
        }
        ghc_after = rd32(abar_virt, HBA_GHC) | GHC_AE;
        wr32(abar_virt, HBA_GHC, ghc_after);
        // Clear global pending interrupts before per-port setup.
        wr32(abar_virt, HBA_IS, 0xFFFF_FFFF);

        log::debug!(
            "AHCI: ABAR phys={:#x} virt={:#x}  GHC={:#010x}",
            abar_phys,
            abar_virt,
            rd32(abar_virt, HBA_GHC)
        );

        let pi = rd32(abar_virt, HBA_PI); // bitmask of implemented ports
        log::debug!("AHCI: ports implemented mask = {:#010x}", pi);

        let mut ports: Vec<AhciPort> = Vec::new();

        for port_num in 0..32u8 {
            if pi & (1 << port_num) == 0 {
                continue;
            }

            let pvirt = abar_virt + 0x100 + (port_num as u64) * 0x80;

            // Check DET: only accept DET=3 (device present + communication)
            let ssts = rd32(pvirt, PORT_SSTS);
            let det = ssts & SSTS_DET_MASK;
            if det != SSTS_DET_COMM {
                log::debug!("AHCI: port {} DET={} — no device, skipping", port_num, det);
                continue;
            }

            // Only handle plain SATA (signature 0x00000101)
            let sig = rd32(pvirt, PORT_SIG);
            if sig != SIG_SATA {
                log::debug!(
                    "AHCI: port {} sig={:#010x} — not plain SATA, skipping",
                    port_num,
                    sig
                );
                continue;
            }

            // Allocate one 4 KB frame for CLB + FIS + CTAB
            let frame = crate::sync::with_irqs_disabled(|token| {
                memory::allocate_frame(token)
            })
            .map_err(|_| AhciError::Alloc)?;

            let mem_phys = frame.start_address.as_u64();
            let mem_virt = phys_to_virt(mem_phys);

            // Zero the frame so HBA sees clean structures
            // SAFETY: mem_virt is valid HHDM-mapped physical memory, 4096 bytes
            ptr::write_bytes(mem_virt as *mut u8, 0, 4096);

            port_rebase(pvirt, mem_phys);

            // Enable per-port interrupts (DHRE + TFEE)
            port_enable_irq(pvirt);

            // Register this port's MMIO address in the per-port static table
            // so the IRQ handler can access it without holding the controller lock.
            PORT_VIRT[port_num as usize].store(pvirt, Ordering::Relaxed);

            // Identify device to read sector count
            let mut port = AhciPort {
                port_num,
                port_virt: pvirt,
                mem_phys,
                mem_virt,
                sector_count: 0,
            };

            let mut id_buf = [0u8; SECTOR_SIZE];
            match submit_cmd(&port, 0, 1, &mut id_buf, false, ATA_IDENTIFY) {
                Ok(()) => {
                    // Words 100-103 (bytes 200-207) of the IDENTIFY data:
                    // total number of user-addressable logical sectors (48-bit)
                    let w0 = u16::from_le_bytes([id_buf[200], id_buf[201]]) as u64;
                    let w1 = u16::from_le_bytes([id_buf[202], id_buf[203]]) as u64;
                    let w2 = u16::from_le_bytes([id_buf[204], id_buf[205]]) as u64;
                    let w3 = u16::from_le_bytes([id_buf[206], id_buf[207]]) as u64;
                    port.sector_count = w0 | (w1 << 16) | (w2 << 32) | (w3 << 48);
                    log::info!(
                        "AHCI: port {} SATA — {} sectors ({} MiB)",
                        port_num,
                        port.sector_count,
                        (port.sector_count * SECTOR_SIZE as u64) / (1024 * 1024)
                    );
                }
                Err(e) => {
                    // Keep the port (sector_count stays 0) so reads fail cleanly.
                    log::warn!("AHCI: port {} IDENTIFY failed: {}", port_num, e);
                }
            }

            ports.push(port);
        }

        if ports.is_empty() {
            return Err(AhciError::NoPort);
        }

        // Store the ABAR virtual address and IRQ line in statics so the
        // interrupt handler can reach them without going through the controller lock.
        AHCI_ABAR_VIRT.store(abar_virt, Ordering::Relaxed);
        AHCI_IRQ_LINE.store(irq_line, Ordering::Relaxed);

        // Enable global HBA interrupts (GHC.IE)
        // SAFETY: MMIO write — all port interrupts already enabled above
        let ghc = rd32(abar_virt, HBA_GHC);
        wr32(abar_virt, HBA_GHC, ghc | GHC_IE);

        log::info!("AHCI: global interrupts enabled (IRQ line {})", irq_line);

        Ok(AhciController { abar_virt, ports })
    }

    /// Return sector count of the first port (0 if no ports or IDENTIFY failed).
    pub fn sector_count(&self) -> u64 {
        self.ports.first().map(|p| p.sector_count).unwrap_or(0)
    }

    /// First initialised port, if any — all block I/O currently targets it.
    fn first_port(&self) -> Option<&AhciPort> {
        self.ports.first()
    }
}
720
721impl BlockDevice for AhciController {
722    /// Reads sector.
723    fn read_sector(&self, sector: u64, buf: &mut [u8]) -> Result<(), BlockError> {
724        let port = self.first_port().ok_or(BlockError::NotReady)?;
725        if sector >= port.sector_count {
726            return Err(BlockError::InvalidSector);
727        }
728        if buf.len() < SECTOR_SIZE {
729            return Err(BlockError::BufferTooSmall);
730        }
731        submit_cmd(port, sector, 1, buf, false, ATA_READ_DMA_EXT).map_err(|_| BlockError::IoError)
732    }
733
734    /// Writes sector.
735    fn write_sector(&self, sector: u64, buf: &[u8]) -> Result<(), BlockError> {
736        let port = self.first_port().ok_or(BlockError::NotReady)?;
737        if sector >= port.sector_count {
738            return Err(BlockError::InvalidSector);
739        }
740        if buf.len() < SECTOR_SIZE {
741            return Err(BlockError::BufferTooSmall);
742        }
743        // submit_cmd needs &mut [u8]; copy to a mutable staging buffer
744        let mut tmp = [0u8; SECTOR_SIZE];
745        tmp.copy_from_slice(&buf[..SECTOR_SIZE]);
746        submit_cmd(port, sector, 1, &mut tmp, true, ATA_WRITE_DMA_EXT)
747            .map_err(|_| BlockError::IoError)
748    }
749
750    /// Performs the sector count operation.
751    fn sector_count(&self) -> u64 {
752        self.sector_count()
753    }
754}
755
756// ─── Global singleton + public API ───────────────────────────────────────────
757
/// Global controller singleton — set once by [`init`] and never cleared.
static AHCI: SpinLock<Option<Box<AhciController>>> = SpinLock::new(None);
759
760/// Scan the PCI bus for an AHCI controller and initialise it.
761///
762/// Called once during kernel boot from `hardware::init()`.
763pub fn init() {
764    log::info!("AHCI: scanning PCI bus...");
765
766    match unsafe { AhciController::init() } {
767        Ok(ctrl) => {
768            *AHCI.lock() = Some(Box::new(ctrl));
769            log::info!("AHCI: controller ready");
770
771            // Register IRQ handler in the IDT now that the controller is live.
772            let irq = AHCI_IRQ_LINE.load(Ordering::Relaxed);
773            crate::arch::x86_64::idt::register_ahci_irq(irq);
774        }
775        Err(AhciError::NoController) => {
776            log::info!("AHCI: no controller found (not a SATA system?)");
777        }
778        Err(e) => {
779            log::error!("AHCI: init failed: {}", e);
780        }
781    }
782}
783
/// Return a reference to the first usable AHCI controller, if any.
///
/// The `'static` borrow is manufactured from the boxed controller held in
/// the global lock; the SpinLock guard is released when this fn returns.
pub fn get_device() -> Option<&'static AhciController> {
    // SAFETY: sound because (a) the global Option is only ever set during init
    // and never cleared, so the Box — and hence its heap allocation — lives
    // for the remainder of the kernel's lifetime, and (b) the Box's pointee
    // address is stable even after the guard is dropped.
    unsafe {
        let lock = AHCI.lock();
        lock.as_ref().map(|b| {
            let ptr = b.as_ref() as *const AhciController;
            &*ptr
        })
    }
}
794}