Skip to main content

strate_fs_abstraction/
safe_math.rs

1//! Safe arithmetic operations with overflow detection.
2//!
3//! This module provides traits and functions for performing arithmetic
4//! operations that detect overflow and return errors instead of panicking
5//! or wrapping.
6//!
7//! # Security
8//!
9//! Arithmetic overflow in filesystem code can lead to serious security
10//! vulnerabilities like buffer overflows. Always use these checked operations
11//! when computing sizes, offsets, or counts from untrusted metadata.
12
13use crate::error::{FsError, FsResult};
14
15/// Trait for checked arithmetic operations.
16///
17/// Provides methods that return `FsError::ArithmeticOverflow` on overflow
18/// instead of panicking or wrapping.
19pub trait CheckedOps: Sized {
20    /// Adds an offset, returning an error on overflow.
21    fn checked_add_offset(self, offset: Self) -> FsResult<Self>;
22
23    /// Multiplies by a size, returning an error on overflow.
24    fn checked_mul_size(self, size: Self) -> FsResult<Self>;
25
26    /// Subtracts, returning an error on underflow.
27    fn checked_sub_safe(self, other: Self) -> FsResult<Self>;
28
29    /// Left-shifts, returning an error if bits would be lost.
30    fn checked_shl_safe(self, shift: u32) -> FsResult<Self>;
31}
32
33impl CheckedOps for u64 {
34    #[inline]
35    /// Implements checked add offset.
36    fn checked_add_offset(self, offset: u64) -> FsResult<Self> {
37        self.checked_add(offset).ok_or(FsError::ArithmeticOverflow)
38    }
39
40    #[inline]
41    /// Implements checked mul size.
42    fn checked_mul_size(self, size: u64) -> FsResult<Self> {
43        self.checked_mul(size).ok_or(FsError::ArithmeticOverflow)
44    }
45
46    #[inline]
47    /// Implements checked sub safe.
48    fn checked_sub_safe(self, other: u64) -> FsResult<Self> {
49        self.checked_sub(other).ok_or(FsError::ArithmeticOverflow)
50    }
51
52    #[inline]
53    /// Implements checked shl safe.
54    fn checked_shl_safe(self, shift: u32) -> FsResult<Self> {
55        if shift >= 64 {
56            return Err(FsError::ArithmeticOverflow);
57        }
58        // Check if any bits would be lost
59        let result = self << shift;
60        if (result >> shift) != self {
61            return Err(FsError::ArithmeticOverflow);
62        }
63        Ok(result)
64    }
65}
66
67impl CheckedOps for u32 {
68    #[inline]
69    /// Implements checked add offset.
70    fn checked_add_offset(self, offset: u32) -> FsResult<Self> {
71        self.checked_add(offset).ok_or(FsError::ArithmeticOverflow)
72    }
73
74    #[inline]
75    /// Implements checked mul size.
76    fn checked_mul_size(self, size: u32) -> FsResult<Self> {
77        self.checked_mul(size).ok_or(FsError::ArithmeticOverflow)
78    }
79
80    #[inline]
81    /// Implements checked sub safe.
82    fn checked_sub_safe(self, other: u32) -> FsResult<Self> {
83        self.checked_sub(other).ok_or(FsError::ArithmeticOverflow)
84    }
85
86    #[inline]
87    /// Implements checked shl safe.
88    fn checked_shl_safe(self, shift: u32) -> FsResult<Self> {
89        if shift >= 32 {
90            return Err(FsError::ArithmeticOverflow);
91        }
92        let result = self << shift;
93        if (result >> shift) != self {
94            return Err(FsError::ArithmeticOverflow);
95        }
96        Ok(result)
97    }
98}
99
100impl CheckedOps for usize {
101    #[inline]
102    /// Implements checked add offset.
103    fn checked_add_offset(self, offset: usize) -> FsResult<Self> {
104        self.checked_add(offset).ok_or(FsError::ArithmeticOverflow)
105    }
106
107    #[inline]
108    /// Implements checked mul size.
109    fn checked_mul_size(self, size: usize) -> FsResult<Self> {
110        self.checked_mul(size).ok_or(FsError::ArithmeticOverflow)
111    }
112
113    #[inline]
114    /// Implements checked sub safe.
115    fn checked_sub_safe(self, other: usize) -> FsResult<Self> {
116        self.checked_sub(other).ok_or(FsError::ArithmeticOverflow)
117    }
118
119    #[inline]
120    /// Implements checked shl safe.
121    fn checked_shl_safe(self, shift: u32) -> FsResult<Self> {
122        if shift >= (core::mem::size_of::<usize>() * 8) as u32 {
123            return Err(FsError::ArithmeticOverflow);
124        }
125        let result = self << shift;
126        if (result >> shift) != self {
127            return Err(FsError::ArithmeticOverflow);
128        }
129        Ok(result)
130    }
131}
132
133/// Extension trait for checked slice operations.
134pub trait CheckedSliceOps {
135    /// Gets a subslice with bounds checking.
136    fn get_checked(&self, start: usize, len: usize) -> FsResult<&[u8]>;
137
138    /// Reads a big-endian u16 at the given offset.
139    fn read_be_u16(&self, offset: usize) -> FsResult<u16>;
140
141    /// Reads a big-endian u32 at the given offset.
142    fn read_be_u32(&self, offset: usize) -> FsResult<u32>;
143
144    /// Reads a big-endian u64 at the given offset.
145    fn read_be_u64(&self, offset: usize) -> FsResult<u64>;
146
147    /// Reads a little-endian u16 at the given offset.
148    fn read_le_u16(&self, offset: usize) -> FsResult<u16>;
149
150    /// Reads a little-endian u32 at the given offset.
151    fn read_le_u32(&self, offset: usize) -> FsResult<u32>;
152
153    /// Reads a little-endian u64 at the given offset.
154    fn read_le_u64(&self, offset: usize) -> FsResult<u64>;
155}
156
157impl CheckedSliceOps for [u8] {
158    #[inline]
159    /// Returns checked.
160    fn get_checked(&self, start: usize, len: usize) -> FsResult<&[u8]> {
161        let end = start.checked_add(len).ok_or(FsError::ArithmeticOverflow)?;
162        if end > self.len() {
163            return Err(FsError::BufferTooSmall);
164        }
165        Ok(&self[start..end])
166    }
167
168    #[inline]
169    /// Reads be u16.
170    fn read_be_u16(&self, offset: usize) -> FsResult<u16> {
171        let bytes = self.get_checked(offset, 2)?;
172        Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
173    }
174
175    #[inline]
176    /// Reads be u32.
177    fn read_be_u32(&self, offset: usize) -> FsResult<u32> {
178        let bytes = self.get_checked(offset, 4)?;
179        Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
180    }
181
182    #[inline]
183    /// Reads be u64.
184    fn read_be_u64(&self, offset: usize) -> FsResult<u64> {
185        let bytes = self.get_checked(offset, 8)?;
186        Ok(u64::from_be_bytes([
187            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
188        ]))
189    }
190
191    #[inline]
192    /// Reads le u16.
193    fn read_le_u16(&self, offset: usize) -> FsResult<u16> {
194        let bytes = self.get_checked(offset, 2)?;
195        Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
196    }
197
198    #[inline]
199    /// Reads le u32.
200    fn read_le_u32(&self, offset: usize) -> FsResult<u32> {
201        let bytes = self.get_checked(offset, 4)?;
202        Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
203    }
204
205    #[inline]
206    /// Reads le u64.
207    fn read_le_u64(&self, offset: usize) -> FsResult<u64> {
208        let bytes = self.get_checked(offset, 8)?;
209        Ok(u64::from_le_bytes([
210            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
211        ]))
212    }
213}
214
/// Divides `a` by `b`, yielding `0` instead of panicking when `b` is zero.
///
/// For any non-zero `b` this is ordinary truncating integer division.
#[inline]
pub const fn saturating_div(a: u64, b: u64) -> u64 {
    match b {
        0 => 0,
        divisor => a / divisor,
    }
}
224
225/// Computes ceil(a / b) without overflow.
226#[inline]
227pub fn div_ceil(a: u64, b: u64) -> FsResult<u64> {
228    if b == 0 {
229        return Err(FsError::ArithmeticOverflow);
230    }
231    // (a + b - 1) / b, but avoid overflow
232    let result = a / b;
233    if a % b != 0 {
234        result.checked_add(1).ok_or(FsError::ArithmeticOverflow)
235    } else {
236        Ok(result)
237    }
238}
239
240/// Aligns a value up to the given alignment.
241///
242/// # Arguments
243/// * `value` - Value to align
244/// * `align` - Alignment (must be a power of 2)
245#[inline]
246pub fn align_up(value: u64, align: u64) -> FsResult<u64> {
247    if !align.is_power_of_two() {
248        return Err(FsError::AlignmentError);
249    }
250    let mask = align - 1;
251    value
252        .checked_add(mask)
253        .map(|v| v & !mask)
254        .ok_or(FsError::ArithmeticOverflow)
255}
256
/// Rounds `value` down to the nearest multiple of `align`.
///
/// # Arguments
/// * `value` - Value to align
/// * `align` - Alignment (must be a power of 2)
///
/// If `align` is not a power of two (zero included), no error is reported:
/// `value` is returned unchanged.
#[inline]
pub const fn align_down(value: u64, align: u64) -> u64 {
    if align.is_power_of_two() {
        value & !(align - 1)
    } else {
        // Invalid alignment: fall back to returning the input untouched.
        value
    }
}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Addition succeeds for in-range sums and fails past the type's maximum.
    #[test]
    fn test_checked_add_overflow() {
        assert!(u64::MAX.checked_add_offset(1).is_err());
        assert_eq!(1u64.checked_add_offset(2).unwrap(), 3);
        assert!(u32::MAX.checked_add_offset(1).is_err());
        assert!(usize::MAX.checked_add_offset(1).is_err());
    }

    /// Multiplication succeeds for in-range products and fails on overflow.
    #[test]
    fn test_checked_mul_overflow() {
        assert!(u64::MAX.checked_mul_size(2).is_err());
        assert_eq!(3u64.checked_mul_size(4).unwrap(), 12);
        assert!(u32::MAX.checked_mul_size(2).is_err());
        assert!(usize::MAX.checked_mul_size(2).is_err());
    }

    /// Subtraction fails when the subtrahend exceeds the minuend.
    #[test]
    fn test_checked_sub_underflow() {
        assert!(matches!(
            0u64.checked_sub_safe(1),
            Err(FsError::ArithmeticOverflow)
        ));
        assert_eq!(5u64.checked_sub_safe(3).unwrap(), 2);
        assert!(0u32.checked_sub_safe(1).is_err());
        assert!(0usize.checked_sub_safe(1).is_err());
    }

    /// Left shifts succeed only while no set bit is shifted out and the shift
    /// amount is below the type's bit width.
    #[test]
    fn test_checked_shl() {
        assert_eq!(1u64.checked_shl_safe(63).unwrap(), 1u64 << 63);
        assert_eq!(0u64.checked_shl_safe(63).unwrap(), 0);
        assert!(2u64.checked_shl_safe(63).is_err());
        assert!(1u64.checked_shl_safe(64).is_err());

        assert_eq!(1u32.checked_shl_safe(31).unwrap(), 1u32 << 31);
        assert!(3u32.checked_shl_safe(31).is_err());
        assert!(1u32.checked_shl_safe(32).is_err());

        assert_eq!(0usize.checked_shl_safe(3).unwrap(), 0);
        assert!(usize::MAX.checked_shl_safe(1).is_err());
    }

    /// Big-endian readers decode the most significant byte first.
    #[test]
    fn test_read_be() {
        let data: [u8; 8] = [0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0];
        assert_eq!(data.read_be_u16(0).unwrap(), 0x1234);
        assert_eq!(data.read_be_u32(0).unwrap(), 0x12345678);
        assert_eq!(data.read_be_u64(0).unwrap(), 0x123456789ABCDEF0);
        assert_eq!(data.read_be_u32(4).unwrap(), 0x9ABCDEF0);
    }

    /// Little-endian readers decode the least significant byte first.
    #[test]
    fn test_read_le() {
        let data: [u8; 8] = [0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0];
        assert_eq!(data.read_le_u16(0).unwrap(), 0x3412);
        assert_eq!(data.read_le_u32(0).unwrap(), 0x78563412);
        assert_eq!(data.read_le_u64(0).unwrap(), 0xF0DEBC9A78563412);
    }

    /// Reads that would run past the end of the buffer are rejected.
    #[test]
    fn test_read_buffer_bounds() {
        let data: [u8; 2] = [0x12, 0x34];
        assert!(data.read_be_u32(0).is_err());
        assert!(matches!(data.read_le_u16(1), Err(FsError::BufferTooSmall)));
    }

    /// `get_checked` distinguishes an out-of-range request from an
    /// overflowing `start + len`.
    #[test]
    fn test_get_checked() {
        let data: [u8; 4] = [1, 2, 3, 4];
        assert_eq!(data.get_checked(1, 2).unwrap(), &data[1..3]);
        assert!(data.get_checked(4, 0).unwrap().is_empty());
        assert!(matches!(data.get_checked(3, 2), Err(FsError::BufferTooSmall)));
        assert!(matches!(
            data.get_checked(1, usize::MAX),
            Err(FsError::ArithmeticOverflow)
        ));
    }

    /// Division by zero yields zero; everything else is plain division.
    #[test]
    fn test_saturating_div() {
        assert_eq!(saturating_div(10, 0), 0);
        assert_eq!(saturating_div(10, 3), 3);
        assert_eq!(saturating_div(u64::MAX, 1), u64::MAX);
    }

    /// Values are rounded up to the next multiple of a power-of-two alignment.
    #[test]
    fn test_align_up() {
        assert_eq!(align_up(0, 4).unwrap(), 0);
        assert_eq!(align_up(1, 4).unwrap(), 4);
        assert_eq!(align_up(4, 4).unwrap(), 4);
        assert_eq!(align_up(5, 4).unwrap(), 8);
    }

    /// Invalid alignments and overflowing roundings are reported as errors.
    #[test]
    fn test_align_up_errors() {
        assert!(matches!(align_up(1, 0), Err(FsError::AlignmentError)));
        assert!(matches!(align_up(1, 3), Err(FsError::AlignmentError)));
        assert!(matches!(
            align_up(u64::MAX, 2),
            Err(FsError::ArithmeticOverflow)
        ));
    }

    /// Values are rounded down to a multiple of a power-of-two alignment and
    /// returned unchanged when the alignment is invalid.
    #[test]
    fn test_align_down() {
        assert_eq!(align_down(0, 4), 0);
        assert_eq!(align_down(5, 4), 4);
        assert_eq!(align_down(8, 4), 8);
        assert_eq!(align_down(7, 0), 7);
        assert_eq!(align_down(7, 3), 7);
    }

    /// Quotients are rounded up, including at the top of the `u64` range.
    #[test]
    fn test_div_ceil() {
        assert_eq!(div_ceil(0, 4).unwrap(), 0);
        assert_eq!(div_ceil(1, 4).unwrap(), 1);
        assert_eq!(div_ceil(4, 4).unwrap(), 1);
        assert_eq!(div_ceil(5, 4).unwrap(), 2);
        assert_eq!(div_ceil(u64::MAX, 1).unwrap(), u64::MAX);
        assert_eq!(div_ceil(u64::MAX, 2).unwrap(), 1u64 << 63);
    }

    /// A zero divisor is reported as an error rather than panicking.
    #[test]
    fn test_div_ceil_zero_divisor() {
        assert!(matches!(div_ceil(1, 0), Err(FsError::ArithmeticOverflow)));
    }
}