// strat9_kernel/syscall/fork.rs
1//! `fork()` syscall implementation with copy-on-write (COW).
2
3use crate::{
4    memory::{AddressSpace, FrameAllocator as _},
5    process::{
6        current_task_clone,
7        scheduler::add_task_with_parent,
8        signal::{SigActionData, SigStack, SignalSet},
9        task::{CpuContext, KernelStack, Pid, SyncUnsafeCell, Task},
10        TaskId, TaskState,
11    },
12    syscall::{error::SyscallError, SyscallFrame},
13};
14use alloc::{boxed::Box, sync::Arc};
15use core::{
16    mem::offset_of,
17    sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
18};
19use x86_64::structures::paging::{mapper::TranslateResult, FrameAllocator}; // Required for allocate_frame
20
/// Result returned by [`sys_fork`].
pub struct ForkResult {
    /// PID of the newly created child process (delivered to the parent;
    /// the child itself observes a 0 return from `fork()`).
    pub child_pid: Pid,
}
25
/// Invalidates the TLB entry covering `vaddr` on the current CPU only.
#[inline]
fn local_invlpg(vaddr: u64) {
    // Local TLB invalidation is sufficient here: this kernel currently runs
    // one task per user address space (no shared user CR3 across CPUs).
    // SAFETY: `invlpg` only evicts a TLB entry; it does not access memory,
    // needs no stack, and preserves flags (as declared in the options).
    unsafe {
        core::arch::asm!("invlpg [{}]", in(reg) vaddr, options(nostack, preserves_flags));
    }
}
35
/// Snapshot of the parent's user-mode register state captured at the
/// `fork()` syscall, heap-allocated and handed to the child's bootstrap.
///
/// Layout is `#[repr(C)]` and consumed by the naked `fork_iret_from_ctx`
/// routine through the `OFF_*` byte offsets below — do NOT reorder fields.
#[repr(C)]
#[derive(Clone, Copy)]
struct ForkUserContext {
    // General-purpose registers to restore before entering Ring 3.
    r15: u64,
    r14: u64,
    r13: u64,
    r12: u64,
    rbp: u64,
    rbx: u64,
    r11: u64,
    r10: u64,
    r9: u64,
    r8: u64,
    rsi: u64,
    rdi: u64,
    rdx: u64,
    rcx: u64,
    // Five-slot `iretq` frame (pushed in reverse order by the asm).
    user_rip: u64,
    user_cs: u64,
    user_rflags: u64,
    user_rsp: u64,
    user_ss: u64,
}
59
// Byte offsets of every `ForkUserContext` field, passed as `const` operands
// into the naked `fork_iret_from_ctx` assembly so the asm and the struct
// layout can never drift apart.
const OFF_R15: usize = offset_of!(ForkUserContext, r15);
const OFF_R14: usize = offset_of!(ForkUserContext, r14);
const OFF_R13: usize = offset_of!(ForkUserContext, r13);
const OFF_R12: usize = offset_of!(ForkUserContext, r12);
const OFF_RBP: usize = offset_of!(ForkUserContext, rbp);
const OFF_RBX: usize = offset_of!(ForkUserContext, rbx);
const OFF_R11: usize = offset_of!(ForkUserContext, r11);
const OFF_R10: usize = offset_of!(ForkUserContext, r10);
const OFF_R9: usize = offset_of!(ForkUserContext, r9);
const OFF_R8: usize = offset_of!(ForkUserContext, r8);
const OFF_RSI: usize = offset_of!(ForkUserContext, rsi);
const OFF_RDI: usize = offset_of!(ForkUserContext, rdi);
const OFF_RDX: usize = offset_of!(ForkUserContext, rdx);
const OFF_RCX: usize = offset_of!(ForkUserContext, rcx);
const OFF_USER_RIP: usize = offset_of!(ForkUserContext, user_rip);
const OFF_USER_CS: usize = offset_of!(ForkUserContext, user_cs);
const OFF_USER_RFLAGS: usize = offset_of!(ForkUserContext, user_rflags);
const OFF_USER_RSP: usize = offset_of!(ForkUserContext, user_rsp);
const OFF_USER_SS: usize = offset_of!(ForkUserContext, user_ss);
79
80/// Child bootstrap: restore user register snapshot and enter Ring 3.
81extern "C" fn fork_child_start(ctx_ptr: u64) -> ! {
82    let boxed = unsafe { Box::from_raw(ctx_ptr as *mut ForkUserContext) };
83    let ctx = *boxed;
84    unsafe { fork_iret_from_ctx(&ctx as *const ForkUserContext) }
85}
86
/// Restores a [`ForkUserContext`] register snapshot and drops to Ring 3
/// with `iretq`. Never returns.
///
/// # Safety
/// `_ctx` must point to a live, fully-initialized `ForkUserContext` whose
/// `user_cs`/`user_ss`/`user_rip`/`user_rsp`/`user_rflags` describe a valid
/// runnable user-mode context; the snapshot must stay valid until `iretq`
/// executes. RAX is not taken from the snapshot — it is forced to 0 so the
/// child observes `fork() == 0`.
#[unsafe(naked)]
unsafe extern "C" fn fork_iret_from_ctx(_ctx: *const ForkUserContext) -> ! {
    core::arch::naked_asm!(
        // SysV: the ctx pointer arrives in rdi; park it in rsi so rdi can be
        // restored before rsi itself (rsi is restored last).
        "mov rsi, rdi",

        // ===== Build IRET frame FIRST, using r8 as scratch ===========
        // (r8 has not been restored yet, so we can clobber it safely)
        "mov r8, [rsi + {off_user_ss}]",
        "push r8",                            // SS
        "mov r8, [rsi + {off_user_rsp}]",
        "push r8",                            // user RSP
        "mov r8, [rsi + {off_user_rflags}]",
        "push r8",                            // user RFLAGS
        "mov r8, [rsi + {off_user_cs}]",
        "push r8",                            // CS
        "mov r8, [rsi + {off_user_rip}]",
        "push r8",                            // user RIP

        // ===== Now restore ALL general-purpose registers============
        "mov r15, [rsi + {off_r15}]",
        "mov r14, [rsi + {off_r14}]",
        "mov r13, [rsi + {off_r13}]",
        "mov r12, [rsi + {off_r12}]",
        "mov rbp, [rsi + {off_rbp}]",
        "mov rbx, [rsi + {off_rbx}]",
        "mov r11, [rsi + {off_r11}]",
        "mov r10, [rsi + {off_r10}]",
        "mov r9,  [rsi + {off_r9}]",
        "mov r8,  [rsi + {off_r8}]",          // r8 now gets its correct value
        "mov rdx, [rsi + {off_rdx}]",
        "mov rcx, [rsi + {off_rcx}]",
        "mov rdi, [rsi + {off_rdi}]",
        "mov rax, 0",                         // child fork() returns 0
        "mov rsi, [rsi + {off_rsi}]",         // rsi restored last
        "iretq",
        off_r15 = const OFF_R15,
        off_r14 = const OFF_R14,
        off_r13 = const OFF_R13,
        off_r12 = const OFF_R12,
        off_rbp = const OFF_RBP,
        off_rbx = const OFF_RBX,
        off_r11 = const OFF_R11,
        off_r10 = const OFF_R10,
        off_r9 = const OFF_R9,
        off_r8 = const OFF_R8,
        off_rsi = const OFF_RSI,
        off_rdi = const OFF_RDI,
        off_rdx = const OFF_RDX,
        off_rcx = const OFF_RCX,
        off_user_rip = const OFF_USER_RIP,
        off_user_cs = const OFF_USER_CS,
        off_user_rflags = const OFF_USER_RFLAGS,
        off_user_rsp = const OFF_USER_RSP,
        off_user_ss = const OFF_USER_SS,
    );
}
144
/// Builds the child [`Task`] for `fork()`: a fresh kernel stack and CPU
/// context that first runs [`fork_child_start`], plus a new `Process` that
/// inherits the parent's credentials, fd table, signal state, cwd, umask,
/// brk/mmap hints and FPU state per POSIX fork semantics.
fn build_child_task(
    parent: &Arc<Task>,
    child_as: Arc<AddressSpace>,
    bootstrap_ctx: Box<ForkUserContext>,
) -> Result<Arc<Task>, SyscallError> {
    let kernel_stack =
        KernelStack::allocate(Task::DEFAULT_STACK_SIZE).map_err(|_| SyscallError::OutOfMemory)?;
    // The child starts in kernel mode at `fork_child_start`, which then
    // irets into the user-register snapshot.
    let context = CpuContext::new(fork_child_start as *const () as u64, &kernel_stack);

    // NOTE(review): these raw-cell reads assume the parent's process state is
    // not concurrently mutated while it executes this syscall — confirm
    // against the kernel's locking model.
    let parent_caps = unsafe { (&*parent.process.capabilities.get()).clone() };
    let parent_fd = unsafe { (&*parent.process.fd_table.get()).clone_for_fork() };
    let parent_blocked = parent.blocked_signals.clone();
    let parent_actions: [SigActionData; 64] = unsafe { *parent.process.signal_actions.get() };
    let parent_sigstack: Option<SigStack> = unsafe { *parent.signal_stack.get() };

    let (pid, tid, tgid) = Task::allocate_process_ids();
    let task = Arc::new(Task {
        id: TaskId::new(),
        pid,
        tid,
        tgid,
        // Process group, session and credentials are inherited from the parent.
        pgid: AtomicU32::new(parent.pgid.load(Ordering::Relaxed)),
        sid: AtomicU32::new(parent.sid.load(Ordering::Relaxed)),
        uid: AtomicU32::new(parent.uid.load(Ordering::Relaxed)),
        euid: AtomicU32::new(parent.euid.load(Ordering::Relaxed)),
        gid: AtomicU32::new(parent.gid.load(Ordering::Relaxed)),
        egid: AtomicU32::new(parent.egid.load(Ordering::Relaxed)),
        state: SyncUnsafeCell::new(TaskState::Ready),
        priority: parent.priority,
        context: SyncUnsafeCell::new(context),
        kernel_stack,
        user_stack: None,

        name: "fork-child",
        process: alloc::sync::Arc::new(crate::process::process::Process {
            pid,
            // The COW-cloned address space produced by `clone_cow()`.
            address_space: crate::process::task::SyncUnsafeCell::new(child_as),
            fd_table: crate::process::task::SyncUnsafeCell::new(parent_fd),
            capabilities: crate::process::task::SyncUnsafeCell::new(parent_caps),
            signal_actions: crate::process::task::SyncUnsafeCell::new(parent_actions),
            brk: core::sync::atomic::AtomicU64::new(
                parent
                    .process
                    .brk
                    .load(core::sync::atomic::Ordering::Relaxed),
            ),
            mmap_hint: core::sync::atomic::AtomicU64::new(
                parent
                    .process
                    .mmap_hint
                    .load(core::sync::atomic::Ordering::Relaxed),
            ),
            cwd: crate::process::task::SyncUnsafeCell::new(
                unsafe { &*parent.process.cwd.get() }.clone(),
            ),
            umask: core::sync::atomic::AtomicU32::new(
                parent
                    .process
                    .umask
                    .load(core::sync::atomic::Ordering::Relaxed),
            ),
        }),
        // POSIX: pending signals are NOT inherited by the child.
        pending_signals: SignalSet::new(),
        // POSIX: signal mask IS inherited.
        blocked_signals: parent_blocked,
        signal_stack: SyncUnsafeCell::new(parent_sigstack),
        itimers: crate::process::timer::ITimers::new(),
        wake_pending: AtomicBool::new(false),
        wake_deadline_ns: AtomicU64::new(0),
        trampoline_entry: AtomicU64::new(0),
        trampoline_stack_top: AtomicU64::new(0),
        trampoline_arg0: AtomicU64::new(0),
        ticks: AtomicU64::new(0),
        sched_policy: SyncUnsafeCell::new(parent.sched_policy()),
        vruntime: AtomicU64::new(parent.vruntime()),
        // POSIX: clear_child_tid is NOT inherited — child starts with 0.
        clear_child_tid: AtomicU64::new(0),
        // POSIX: cwd IS inherited.
        // POSIX: umask IS inherited.
        // FS.base: child starts with 0 (its own TLS not yet set up).
        user_fs_base: AtomicU64::new(0),
        fpu_state: {
            let parent_fpu = unsafe { &*parent.fpu_state.get() };
            let mut child_fpu = crate::process::task::ExtendedState::new();
            child_fpu.copy_from(parent_fpu);
            SyncUnsafeCell::new(child_fpu)
        },
        xcr0_mask: AtomicU64::new(parent.xcr0_mask.load(core::sync::atomic::Ordering::Relaxed)),
    });

    // CpuContext initial stack layout: r15, r14, r13(arg), r12(entry), rbp, rbx, ret
    unsafe {
        let ctx = &mut *task.context.get();
        let frame = ctx.saved_rsp as *mut u64;
        // Slot 2 (r13) is the first-argument slot for `fork_child_start`:
        // ownership of the boxed snapshot is transferred via the raw pointer.
        *frame.add(2) = Box::into_raw(bootstrap_ctx) as u64;
    }

    Ok(task)
}
246
247/// SYS_PROC_FORK (302): fork with copy-on-write address-space cloning.
248pub fn sys_fork(frame: &SyscallFrame) -> Result<ForkResult, SyscallError> {
249    let parent = current_task_clone().ok_or(SyscallError::PermissionDenied)?;
250
251    // 1. Sanity check: cannot fork a kernel thread.
252    if parent.is_kernel() {
253        log::warn!("fork: attempt to fork kernel thread '{}'", parent.name);
254        return Err(SyscallError::PermissionDenied);
255    }
256
257    // 2. Capability check: check if task is restricted from forking.
258    // For now, we allow fork for all user processes unless restricted.
259    // TODO: implement ResourceType::Process/Task restricted capabilities.
260
261    let parent_as = unsafe { &*parent.process.address_space.get() };
262
263    // 3. Memory check: ensure parent has actual user-space mappings.
264    if !parent_as.has_user_mappings() {
265        log::warn!(
266            "fork: attempt to fork task '{}' with no user mappings",
267            parent.name
268        );
269        return Err(SyscallError::InvalidArgument);
270    }
271
272    let child_as = parent_as
273        .clone_cow()
274        .map_err(|_| SyscallError::OutOfMemory)?;
275
276    let child_user_ctx = Box::new(ForkUserContext {
277        r15: frame.r15,
278        r14: frame.r14,
279        r13: frame.r13,
280        r12: frame.r12,
281        rbp: frame.rbp,
282        rbx: frame.rbx,
283        r11: frame.r11,
284        r10: frame.r10,
285        r9: frame.r9,
286        r8: frame.r8,
287        rsi: frame.rsi,
288        rdi: frame.rdi,
289        rdx: frame.rdx,
290        rcx: frame.rcx,
291        user_rip: frame.iret_rip,
292        user_cs: frame.iret_cs,
293        user_rflags: frame.iret_rflags,
294        user_rsp: frame.iret_rsp,
295        user_ss: frame.iret_ss,
296    });
297
298    let child_task = build_child_task(&parent, child_as, child_user_ctx)?;
299    let child_pid = child_task.pid;
300    add_task_with_parent(child_task, parent.id);
301
302    Ok(ForkResult { child_pid })
303}
304
305/// Called from the page fault handler when a write fault occurs on a present page.
306/// Returns Ok(()) if the fault was successfully handled (COW resolution),
307/// or Err if it wasn't a COW fault (real access violation).
308pub fn handle_cow_fault(virt_addr: u64, address_space: &AddressSpace) -> Result<(), &'static str> {
309    use crate::memory::paging::BuddyFrameAllocator;
310    use x86_64::{
311        structures::paging::{Mapper, Page, PageTableFlags, Size4KiB, Translate},
312        VirtAddr,
313    };
314
315    let page = Page::<Size4KiB>::containing_address(VirtAddr::new(virt_addr));
316
317    // SAFETY: we are in an exception handler, address space is active.
318    let mut mapper = unsafe { address_space.mapper() };
319
320    //Check if page is mapped and has COW flag
321    let (phys_frame, flags): (
322        x86_64::structures::paging::PhysFrame<Size4KiB>,
323        PageTableFlags,
324    ) = match mapper.translate(VirtAddr::new(virt_addr)) {
325        TranslateResult::Mapped {
326            frame: x86_64::structures::paging::mapper::MappedFrame::Size4KiB(frame),
327            offset: _,
328            flags,
329        } => (frame, flags),
330        _ => return Err("Page not mapped or huge page"),
331    };
332
333    // We use BIT_9 as software COW flag
334    const COW_BIT: PageTableFlags = PageTableFlags::BIT_9;
335
336    if !flags.contains(COW_BIT) {
337        return Err("Not a COW page");
338    }
339
340    let old_frame = crate::memory::PhysFrame {
341        start_address: phys_frame.start_address(),
342    };
343
344    let refcount = crate::memory::cow::frame_get_refcount(old_frame);
345
346    if refcount == 1 {
347        // Case 1: we are the sole owner. Just make it writable.
348        let new_flags = (flags | PageTableFlags::WRITABLE) & !COW_BIT;
349
350        unsafe {
351            mapper
352                .update_flags(page, new_flags)
353                .map_err(|_| "Failed to update flags")?
354                .flush();
355        }
356        // Only the current CPU can hold this CR3 in the current design.
357        local_invlpg(virt_addr);
358        return Ok(());
359    }
360
361    // Case 2: shared page. Copy to new frame.
362    let mut frame_allocator = BuddyFrameAllocator;
363
364    let new_frame = frame_allocator
365        .allocate_frame()
366        .ok_or("OOM during COW copy")?;
367
368    // Copy content
369    unsafe {
370        let src = crate::memory::phys_to_virt(old_frame.start_address.as_u64()) as *const u8;
371        let dst = crate::memory::phys_to_virt(new_frame.start_address().as_u64()) as *mut u8;
372        core::ptr::copy_nonoverlapping(src, dst, 4096);
373    }
374
375    // Update mapping to new frame, Writable, no COW
376    let new_flags = (flags | PageTableFlags::WRITABLE) & !COW_BIT;
377
378    // Replace existing mapping (present+COW) by the private writable mapping.
379    let old_unmapped = mapper
380        .unmap(page)
381        .map_err(|_| "Failed to unmap old COW frame")?
382        .0;
383    debug_assert_eq!(old_unmapped.start_address(), old_frame.start_address);
384
385    let remap_res = unsafe { mapper.map_to(page, new_frame, new_flags, &mut frame_allocator) };
386    if remap_res.is_err() {
387        unsafe {
388            let _ = mapper.map_to(page, phys_frame, flags, &mut frame_allocator);
389        }
390        crate::sync::with_irqs_disabled(|token| {
391            crate::memory::free_frame(token, crate::memory::PhysFrame {
392                start_address: new_frame.start_address(),
393            });
394        });
395        return Err("Failed to map new COW frame");
396    }
397    match remap_res {
398        Ok(flush) => flush.flush(),
399        Err(_) => unreachable!("checked remap result above"),
400    }
401
402    crate::memory::cow::frame_inc_ref(crate::memory::PhysFrame {
403        start_address: new_frame.start_address(),
404    });
405
406    // Only the current CPU can hold this CR3 in the current design.
407    local_invlpg(virt_addr);
408
409    // Decrement refcount of old frame after the new mapping is installed.
410    crate::memory::cow::frame_dec_ref(old_frame);
411
412    Ok(())
413}