diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2024-04-16 23:00:39 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2024-04-16 23:00:39 +0000 |
commit | 0374097d7efc61030dd81f45ef21d193b6594d3a (patch) | |
tree | 570a3ad2048e4b3ef70475fabb43c717cf850c44 | |
parent | 18b5393108ff31363e47c4ae01da2cc4c0600e93 (diff) | |
parent | 5c5ea94606b98d0113ab1e89f4ff5304af5c27c3 (diff) | |
download | libbootloader-0374097d7efc61030dd81f45ef21d193b6594d3a.tar.gz |
Snap for 11724015 from 5c5ea94606b98d0113ab1e89f4ff5304af5c27c3 to sdk-release
Change-Id: Id179585bdca60261f4fe9c6c4819f8f9798fedf6
-rw-r--r-- | gbl/efi/BUILD | 1 | ||||
-rw-r--r-- | gbl/efi/src/avb.rs | 101 | ||||
-rw-r--r-- | gbl/libavb/BUILD | 9 | ||||
-rw-r--r-- | gbl/libavb/BUILD.avb.bazel | 12 | ||||
-rw-r--r-- | gbl/libavb/src/lib.rs | 154 | ||||
-rw-r--r-- | gbl/libsafemath/BUILD | 30 | ||||
-rw-r--r-- | gbl/libsafemath/Cargo.toml | 17 | ||||
-rw-r--r-- | gbl/libsafemath/src/lib.rs | 487 | ||||
-rw-r--r-- | gbl/tests/BUILD | 1 |
9 files changed, 708 insertions, 104 deletions
diff --git a/gbl/efi/BUILD b/gbl/efi/BUILD index f8e8a10..73f9683 100644 --- a/gbl/efi/BUILD +++ b/gbl/efi/BUILD @@ -44,6 +44,7 @@ rust_binary( deps = [ "@avb", "@avb//:avb_crypto_ops_sha_impl_staticlib", + "@gbl//libavb:sysdeps", "@gbl//libboot", "@gbl//libbootconfig", "@gbl//libbootimg", diff --git a/gbl/efi/src/avb.rs b/gbl/efi/src/avb.rs index 049f0f0..5a80000 100644 --- a/gbl/efi/src/avb.rs +++ b/gbl/efi/src/avb.rs @@ -16,15 +16,14 @@ // is mainly for use by the boot demo. Eventually, these backends will be implemented from the // `GblOps` interface in libgbl, where EFI services will be one level lower as its backend instead. -use core::cmp::{min, Ord}; -use core::ffi::CStr; - use crate::utils::EfiMultiBlockDevices; use avb::{IoError, IoResult, Ops, PublicKeyForPartitionInfo}; -use efi::{efi_free, efi_malloc}; +use core::ffi::CStr; use gbl_storage::AsMultiBlockDevices; use uuid::Uuid; +extern crate avb_sysdeps; + pub struct GblEfiAvbOps<'a, 'b> { gpt_dev: &'b mut EfiMultiBlockDevices<'a>, preloaded_partitions: Option<&'b [(&'b str, &'b [u8])]>, @@ -156,97 +155,3 @@ impl<'b> Ops<'b> for GblEfiAvbOps<'_, 'b> { unimplemented!(); } } - -#[no_mangle] -pub extern "C" fn avb_abort() -> ! { - panic!("avb_abort"); -} - -#[no_mangle] -pub extern "C" fn avb_malloc_(size: usize) -> *mut core::ffi::c_void { - efi_malloc(size) as *mut _ -} - -#[no_mangle] -pub extern "C" fn avb_free(ptr: *mut core::ffi::c_void) { - efi_free(ptr as *mut _); -} - -#[no_mangle] -pub extern "C" fn avb_strlen(s: *const core::ffi::c_char) -> usize { - // SAFETY: libavb guarantees to pass valid NULL-terminated strings to this function. The - // returned reference is only used to compute string length. - unsafe { CStr::from_ptr(s as *const _) }.to_bytes().len() -} - -#[no_mangle] -pub extern "C" fn avb_div_by_10(dividend: *mut u64) -> u32 { - // SAFETY: libavb guarantees to pass valid pointer to u64 integer here - let val = unsafe { &mut *dividend }; - let rem = *val % 10; - *val /= 10; - rem.try_into().unwrap() -} - -#[no_mangle] -pub extern "C" fn avb_memcpy( - dest: *mut core::ffi::c_void, - src: *const core::ffi::c_void, - n: usize, -) -> *mut core::ffi::c_void { - // SAFETY: libavb guarantees to pass valid pointers. - unsafe { (src.cast::<u8>()).copy_to(dest as *mut _, n) }; - dest -} - -#[no_mangle] -pub extern "C" fn avb_memcmp( - src1: *const core::ffi::c_void, - src2: *const core::ffi::c_void, - n: usize, -) -> core::ffi::c_int { - // SAFETY: libavb guarantees to pass valid pointers. References are only used within function. - let (lhs, rhs) = unsafe { - ( - core::slice::from_raw_parts(src1 as *const u8, n), - core::slice::from_raw_parts(src2 as *const u8, n), - ) - }; - Ord::cmp(lhs, rhs) as i32 -} - -#[no_mangle] -pub extern "C" fn avb_strcmp( - s1: *const core::ffi::c_char, - s2: *const core::ffi::c_char, -) -> core::ffi::c_int { - // SAFETY: libavb guarantees to pass valid NULL-terminated strings. References are only used - // within function. - let (lhs, rhs) = unsafe { (CStr::from_ptr(s1 as *const _), CStr::from_ptr(s2 as *const _)) }; - Ord::cmp(lhs, rhs) as i32 -} - -#[no_mangle] -pub extern "C" fn avb_strncmp( - s1: *const core::ffi::c_char, - s2: *const core::ffi::c_char, - n: usize, -) -> core::ffi::c_int { - // SAFETY: libavb guarantees to pass valid NULL-terminated strings. References are only used - // within function. - let (lhs, rhs) = unsafe { (CStr::from_ptr(s1 as *const _), CStr::from_ptr(s2 as *const _)) }; - let cmp_size = min(min(lhs.to_bytes().len(), rhs.to_bytes().len()), n); - Ord::cmp(&lhs.to_bytes()[..cmp_size], &rhs.to_bytes()[..cmp_size]) as i32 -} - -#[no_mangle] -pub extern "C" fn avb_memset( - dest: *mut core::ffi::c_void, - c: core::ffi::c_int, - n: usize, -) -> *mut core::ffi::c_void { - // SAFETY: libavb guarantees to pass valid buffer. Reference is only used within function. - let arr = unsafe { core::slice::from_raw_parts_mut(dest as *mut u8, n) }; - arr.fill(c as u8); - dest -} diff --git a/gbl/libavb/BUILD b/gbl/libavb/BUILD index 277329b..2d05993 100644 --- a/gbl/libavb/BUILD +++ b/gbl/libavb/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_rust//rust:defs.bzl", "rust_library") + package( default_visibility = ["//visibility:public"], ) @@ -23,3 +25,10 @@ cc_library( hdrs = ["deps/include/inttypes.h"], includes = ["./deps/include"], ) + +rust_library( + name = "sysdeps", + srcs = glob(["**/*.rs"]), + crate_name = "avb_sysdeps", + edition = "2021", +) diff --git a/gbl/libavb/BUILD.avb.bazel b/gbl/libavb/BUILD.avb.bazel index 9ac978c..c556611 100644 --- a/gbl/libavb/BUILD.avb.bazel +++ b/gbl/libavb/BUILD.avb.bazel @@ -55,10 +55,10 @@ cc_library( "libavb/avb_version.h", "libavb/libavb.h", "libavb/sha/avb_crypto_ops_impl.h", - "libavb_atx/avb_atx_ops.h", - "libavb_atx/avb_atx_types.h", - "libavb_atx/avb_atx_validate.h", - "libavb_atx/libavb_atx.h", + "libavb_cert/avb_cert_ops.h", + "libavb_cert/avb_cert_types.h", + "libavb_cert/avb_cert_validate.h", + "libavb_cert/libavb_cert.h", ], includes = [ ".", @@ -85,7 +85,7 @@ cc_library( "libavb/avb_util.c", "libavb/avb_vbmeta_image.c", "libavb/avb_version.c", - "libavb_atx/avb_atx_validate.c", + "libavb_cert/avb_cert_validate.c", # Contains noop placeholder for avb_printv/avb_printf "@gbl//libavb:print.c", @@ -168,7 +168,7 @@ use zerocopy::{FromBytes, FromZeroes}; "-I{}".format(LLVM_PREBUILTS_C_INCLUDE), "-nostdinc", ], - header = "libavb_atx/libavb_atx.h", + header = "libavb_cert/libavb_cert.h", ) rust_library( diff --git a/gbl/libavb/src/lib.rs b/gbl/libavb/src/lib.rs new file mode 100644 index 0000000..fc2d987 --- /dev/null +++ b/gbl/libavb/src/lib.rs @@ -0,0 +1,154 @@ +// Copyright 2024, The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This file provides one possible implementation of the sysdeps functions for libavb. +//! Global allocator is required. + +#![cfg_attr(not(test), no_std)] + +extern crate alloc; +use alloc::alloc::{alloc, dealloc}; +use core::{ + alloc::Layout, + cmp::{min, Ord}, + ffi::CStr, + mem::size_of, + ptr::{null_mut, NonNull}, +}; + +/// `avb_malloc_()` requires allocation to be word aligned. +const AVB_MALLOC_ALIGNMENT: usize = 2; + +#[no_mangle] +pub extern "C" fn avb_abort() -> ! { + panic!("avb_abort"); +} + +#[no_mangle] +pub extern "C" fn avb_malloc_(size: usize) -> *mut core::ffi::c_void { + (|| { + // Allocate extra to store the size value. + let size = size_of::<usize>().checked_add(size)?; + // SAFETY: + // * On success, `alloc` guarantees to allocate enough memory. + // * `size.to_le_bytes().as_ptr()` is guaranteed valid memory. + // * Alignment is 1 for bytes copy. + unsafe { + let ptr = + NonNull::new(alloc(Layout::from_size_align(size, AVB_MALLOC_ALIGNMENT).ok()?))?; + ptr.as_ptr().copy_from(size.to_le_bytes().as_ptr(), size_of::<usize>()); + let ptr = ptr.as_ptr().add(size_of::<usize>()); + Some(ptr) + } + })() + .unwrap_or(null_mut()) as _ +} + +#[no_mangle] +pub extern "C" fn avb_free(ptr: *mut core::ffi::c_void) { + assert_ne!(ptr, null_mut()); + let mut ptr = ptr as *mut u8; + let mut size_bytes = [0u8; size_of::<usize>()]; + // SAFETY: + // * `ptr` is allocated by `avb_malloc_` and guarantees to have enough memory for a preceding + // usize value and payload. + // * `size_bytes.as_mut_ptr()` is a valid memory location. + // * Alignment is 1 for bytes copy. + unsafe { + ptr = ptr.sub(size_of::<usize>()); + ptr.copy_to(size_bytes.as_mut_ptr(), size_of::<usize>()) + }; + let size = usize::from_le_bytes(size_bytes); + // SAFETY: Call to global allocator. + unsafe { dealloc(ptr, Layout::from_size_align(size, AVB_MALLOC_ALIGNMENT).unwrap()) }; +} + +#[no_mangle] +pub extern "C" fn avb_strlen(s: *const core::ffi::c_char) -> usize { + // SAFETY: libavb guarantees to pass valid NULL-terminated strings to this function. The + // returned reference is only used to compute string length. + unsafe { CStr::from_ptr(s as *const _) }.to_bytes().len() +} + +#[no_mangle] +pub extern "C" fn avb_div_by_10(dividend: *mut u64) -> u32 { + // SAFETY: libavb guarantees to pass valid pointer to u64 integer here + let val = unsafe { &mut *dividend }; + let rem = *val % 10; + *val /= 10; + rem.try_into().unwrap() +} + +#[no_mangle] +pub extern "C" fn avb_memcpy( + dest: *mut core::ffi::c_void, + src: *const core::ffi::c_void, + n: usize, +) -> *mut core::ffi::c_void { + // SAFETY: libavb guarantees to pass valid pointers. + unsafe { (src.cast::<u8>()).copy_to(dest as *mut _, n) }; + dest +} + +#[no_mangle] +pub extern "C" fn avb_memcmp( + src1: *const core::ffi::c_void, + src2: *const core::ffi::c_void, + n: usize, +) -> core::ffi::c_int { + // SAFETY: libavb guarantees to pass valid pointers. References are only used within function. + let (lhs, rhs) = unsafe { + ( + core::slice::from_raw_parts(src1 as *const u8, n), + core::slice::from_raw_parts(src2 as *const u8, n), + ) + }; + Ord::cmp(lhs, rhs) as i32 +} + +#[no_mangle] +pub extern "C" fn avb_strcmp( + s1: *const core::ffi::c_char, + s2: *const core::ffi::c_char, +) -> core::ffi::c_int { + // SAFETY: libavb guarantees to pass valid NULL-terminated strings. References are only used + // within function. + let (lhs, rhs) = unsafe { (CStr::from_ptr(s1 as *const _), CStr::from_ptr(s2 as *const _)) }; + Ord::cmp(lhs, rhs) as i32 +} + +#[no_mangle] +pub extern "C" fn avb_strncmp( + s1: *const core::ffi::c_char, + s2: *const core::ffi::c_char, + n: usize, +) -> core::ffi::c_int { + // SAFETY: libavb guarantees to pass valid NULL-terminated strings. References are only used + // within function. + let (lhs, rhs) = unsafe { (CStr::from_ptr(s1 as *const _), CStr::from_ptr(s2 as *const _)) }; + let cmp_size = min(min(lhs.to_bytes().len(), rhs.to_bytes().len()), n); + Ord::cmp(&lhs.to_bytes()[..cmp_size], &rhs.to_bytes()[..cmp_size]) as i32 +} + +#[no_mangle] +pub extern "C" fn avb_memset( + dest: *mut core::ffi::c_void, + c: core::ffi::c_int, + n: usize, +) -> *mut core::ffi::c_void { + // SAFETY: libavb guarantees to pass valid buffer. Reference is only used within function. + let arr = unsafe { core::slice::from_raw_parts_mut(dest as *mut u8, n) }; + arr.fill(c as u8); + dest +} diff --git a/gbl/libsafemath/BUILD b/gbl/libsafemath/BUILD new file mode 100644 index 0000000..ea163f9 --- /dev/null +++ b/gbl/libsafemath/BUILD @@ -0,0 +1,30 @@ +# Copyright (C) 2024 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test") + +rust_library( + name = "libsafemath", + srcs = [ + "src/lib.rs", + ], + crate_name = "safemath", + edition = "2021", + visibility = ["//visibility:public"], +) + +rust_test( + name = "libsafemath_test", + crate = ":libsafemath", +) diff --git a/gbl/libsafemath/Cargo.toml b/gbl/libsafemath/Cargo.toml new file mode 100644 index 0000000..79a4588 --- /dev/null +++ b/gbl/libsafemath/Cargo.toml @@ -0,0 +1,17 @@ +# Copyright (C) 2024 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[package] +name = "safemath" +version = "0.1.0" diff --git a/gbl/libsafemath/src/lib.rs b/gbl/libsafemath/src/lib.rs new file mode 100644 index 0000000..0cdaa7a --- /dev/null +++ b/gbl/libsafemath/src/lib.rs @@ -0,0 +1,487 @@ +// Copyright (C) 2024 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! # safemath library +//! +//! This library provides an API to safely work with unsigned integers. At a high level, all math +//! operations are checked by default rather than having to remember to call specific `checked_*` +//! functions, so that the burden is on the programmer if they want to perform unchecked math +//! rather than the other way around: +//! +//! ``` +//! use safemath::SafeNum; +//! +//! let safe = SafeNum::from(0); +//! let result = safe - 1; +//! assert!(u32::try_from(result).is_err()); +//! +//! let safe_chain = (SafeNum::from(BIG_NUMBER) * HUGE_NUMBER) / MAYBE_ZERO; +//! // If any operation would have caused an overflow or division by zero, +//! // the number is flagged and the lexical location is specified for logging. +//! if safe_chain.has_error() { +//! eprintln!("safe_chain error = {:#?}", safe_chain); +//! } +//! ``` +//! +//! In addition to checked-by-default arithmetic, the API exposed here support +//! more natural usage than the `checked_*` functions by allowing chaining +//! of operations without having to check the result at each step. +//! This is similar to how floating-point `NaN` works - you can continue to use the +//! value, but continued operations will just propagate `NaN`. +//! +//! ## Supported Operations +//! +//! ### Arithmetic +//! The basic arithmetic operations are supported: +//! addition, subtraction, multiplication, division, and remainder. +//! The right hand side may be another SafeNum or any integer, +//! and the result is always another SafeNum. +//! If the operation would result in an overflow or division by zero, +//! or if converting the right hand element to a `u64` would cause an error, +//! the result is an error-tagged SafeNum that tracks the lexical origin of the error. +//! +//! ### Conversion from and to SafeNum +//! SafeNums support conversion to and from all integer types. +//! Conversion to SafeNum from signed integers and from usize and u128 +//! can fail, generating an error value that is then propagated. +//! Conversion from SafeNum to all integers is only exposed via `try_from` +//! in order to force the user to handle potential resultant errors. +//! +//! E.g. +//! ``` +//! fn call_func(_: u32, _: u32) { +//! } +//! +//! fn do_a_thing(a: SafeNum) -> Result<(), safemath::Error> { +//! call_func(16, a.try_into()?); +//! Ok(()) +//! } +//! ``` +//! +//! ### Comparison +//! SafeNums can be checked for equality against each other. +//! Valid numbers are equal to other numbers of the same magnitude. +//! Errored SafeNums are only equal to themselves. +//! Note that because errors propagate from their first introduction in an +//! arithmetic chain this can lead to surprising results. +//! +//! E.g. +//! ``` +//! let overflow = SafeNum::MAX + 1; +//! let otherflow = SafeNum::MAX + 1; +//! +//! assert_ne!(overflow, otherflow); +//! assert_eq!(overflow + otherflow, overflow); +//! assert_eq!(otherflow + overflow, otherflow); +//! ``` +//! +//! Inequality comparison operators are deliberately not provided. +//! By necessity they would have similar caveats to floating point comparisons, +//! which are easy to use incorrectly and unintuitive to use correctly. +//! +//! The required alternative is to convert to a real integer type before comparing, +//! forcing any errors upwards. +//! +//! E.g. +//! ``` +//! impl From<safemath::Error> for &'static str { +//! fn from(_: safemath::Error) -> Self { +//! "checked arithmetic error" +//! } +//! } +//! +//! fn my_op(a: SafeNum, b: SafeNum, c: SafeNum, d: SafeNum) -> Result<bool, &'static str> { +//! Ok(safemath::Primitive::try_from(a)? < b.try_into()? +//! && safemath::Primitive::try_from(c)? >= d.try_into()?) +//! } +//! ``` +//! +//! ### Miscellaneous +//! SafeNums also provide helper methods to round up or down +//! to the nearest multiple of another number +//! and helper predicate methods that indicate whether the SafeNum +//! is valid or is tracking an error. +//! +//! Also provided are constants `SafeNum::MAX`, `SafeNum::MIN`, and `SafeNum::ZERO`. +//! +//! Warning: SafeNums can help prevent, isolate, and detect arithmetic overflow +//! but they are not a panacea. In particular, chains of different operations +//! are not guaranteed to be associative or commutative. +//! +//! E.g. +//! ``` +//! let a = SafeNum::MAX - 1 + 1; +//! let b = SafeNum::MAX + 1 - 1; +//! assert_ne!(a, b); +//! assert!(a.is_valid()); +//! assert!(b.has_error()); +//! +//! let c = (SafeNum::MAX + 31) / 31; +//! let d = SafeNum::MAX / 31 + 31 / 31; +//! assert_ne!(c, d); +//! assert!(c.has_error()); +//! assert!(d.is_valid()); +//! ``` +//! +//! Note: SafeNum arithmetic is much slower than arithmetic on integer primitives. +//! If you are concerned about performance, be sure to run benchmarks. + +#![cfg_attr(not(test), no_std)] + +use core::convert::TryFrom; +use core::fmt; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign}; +use core::panic::Location; + +pub type Primitive = u64; +pub type Error = &'static Location<'static>; + +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct SafeNum(Result<Primitive, Error>); + +impl fmt::Debug for SafeNum { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + Ok(val) => write!(f, "{}", val), + Err(location) => write!(f, "error at {}", location), + } + } +} + +impl SafeNum { + pub const MAX: SafeNum = SafeNum(Ok(u64::MAX)); + pub const MIN: SafeNum = SafeNum(Ok(u64::MIN)); + pub const ZERO: SafeNum = SafeNum(Ok(0)); + + /// Round `self` down to the nearest multiple of `rhs`. + #[track_caller] + pub fn round_down<T>(self, rhs: T) -> Self + where + Self: Rem<T, Output = Self>, + { + self - (self % rhs) + } + + /// Round `self` up to the nearest multiple of `rhs`. + #[track_caller] + pub fn round_up<T>(self, rhs: T) -> Self + where + Self: Add<T, Output = Self>, + T: Copy + Into<Self>, + { + ((self + rhs) - 1).round_down(rhs) + } + + /// Returns whether self is the result of an operation that has errored. + pub const fn has_error(&self) -> bool { + self.0.is_err() + } + + /// Returns whether self represents a valid, non-overflowed integer. + pub const fn is_valid(&self) -> bool { + self.0.is_ok() + } +} + +macro_rules! try_conversion_func { + ($other_type:tt) => { + impl TryFrom<SafeNum> for $other_type { + type Error = Error; + + #[track_caller] + fn try_from(val: SafeNum) -> Result<Self, Self::Error> { + Self::try_from(val.0?).map_err(|_| Location::caller()) + } + } + }; +} + +macro_rules! conversion_func { + ($from_type:tt) => { + impl From<$from_type> for SafeNum { + fn from(val: $from_type) -> SafeNum { + Self(Ok(val.into())) + } + } + + try_conversion_func!($from_type); + }; +} + +macro_rules! conversion_func_maybe_error { + ($from_type:tt) => { + impl From<$from_type> for SafeNum { + #[track_caller] + fn from(val: $from_type) -> Self { + Self(Primitive::try_from(val).map_err(|_| Location::caller())) + } + } + + try_conversion_func!($from_type); + }; +} + +macro_rules! arithmetic_impl { + ($trait_name:ident, $op:ident, $assign_trait_name:ident, $assign_op:ident, $func:ident) => { + impl<T: Into<SafeNum>> $trait_name<T> for SafeNum { + type Output = Self; + #[track_caller] + fn $op(self, rhs: T) -> Self { + let rhs: Self = rhs.into(); + + match (self.0, rhs.0) { + (Err(_), _) => self, + (_, Err(_)) => rhs, + (Ok(lhs), Ok(rhs)) => Self(lhs.$func(rhs).ok_or_else(Location::caller)), + } + } + } + + impl<T> $assign_trait_name<T> for SafeNum + where + Self: $trait_name<T, Output = Self>, + { + #[track_caller] + fn $assign_op(&mut self, rhs: T) { + *self = self.$op(rhs) + } + } + }; +} + +conversion_func!(u8); +conversion_func!(u16); +conversion_func!(u32); +conversion_func!(u64); +conversion_func_maybe_error!(usize); +conversion_func_maybe_error!(u128); +conversion_func_maybe_error!(i8); +conversion_func_maybe_error!(i16); +conversion_func_maybe_error!(i32); +conversion_func_maybe_error!(i64); +conversion_func_maybe_error!(i128); +conversion_func_maybe_error!(isize); +arithmetic_impl!(Add, add, AddAssign, add_assign, checked_add); +arithmetic_impl!(Sub, sub, SubAssign, sub_assign, checked_sub); +arithmetic_impl!(Mul, mul, MulAssign, mul_assign, checked_mul); +arithmetic_impl!(Div, div, DivAssign, div_assign, checked_div); +arithmetic_impl!(Rem, rem, RemAssign, rem_assign, checked_rem); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_addition() { + let a: SafeNum = 2100.into(); + let b: SafeNum = 12.into(); + assert_eq!(a + b, 2112.into()); + } + + #[test] + fn test_subtraction() { + let a: SafeNum = 667.into(); + let b: SafeNum = 1.into(); + assert_eq!(a - b, 666.into()); + } + + #[test] + fn test_multiplication() { + let a: SafeNum = 17.into(); + let b: SafeNum = 3.into(); + assert_eq!(a * b, 51.into()); + } + + #[test] + fn test_division() { + let a: SafeNum = 1066.into(); + let b: SafeNum = 41.into(); + assert_eq!(a / b, 26.into()); + } + + #[test] + fn test_remainder() { + let a: SafeNum = 613.into(); + let b: SafeNum = 10.into(); + assert_eq!(a % b, 3.into()); + } + + #[test] + fn test_addition_poison() { + let base: SafeNum = 2.into(); + let poison = base + SafeNum::MAX; + assert!(u64::try_from(poison).is_err()); + + let a = poison - 1; + let b = poison - 2; + + assert_eq!(a, poison); + assert_eq!(b, poison); + } + + #[test] + fn test_subtraction_poison() { + let base: SafeNum = 2.into(); + let poison = base - SafeNum::MAX; + assert!(u64::try_from(poison).is_err()); + + let a = poison + 1; + let b = poison + 2; + + assert_eq!(a, poison); + assert_eq!(b, poison); + } + + #[test] + fn test_multiplication_poison() { + let base: SafeNum = 2.into(); + let poison = base * SafeNum::MAX; + assert!(u64::try_from(poison).is_err()); + + let a = poison / 2; + let b = poison / 4; + + assert_eq!(a, poison); + assert_eq!(b, poison); + } + + #[test] + fn test_division_poison() { + let base: SafeNum = 2.into(); + let poison = base / 0; + assert!(u64::try_from(poison).is_err()); + + let a = poison * 2; + let b = poison * 4; + + assert_eq!(a, poison); + assert_eq!(b, poison); + } + + #[test] + fn test_remainder_poison() { + let base: SafeNum = 2.into(); + let poison = base % 0; + assert!(u64::try_from(poison).is_err()); + + let a = poison * 2; + let b = poison * 4; + + assert_eq!(a, poison); + assert_eq!(b, poison); + } + + macro_rules! conversion_test { + ($name:ident) => { + mod $name { + use super::*; + use core::convert::TryInto; + + #[test] + fn test_between_safenum() { + let var: $name = 16; + let sn: SafeNum = var.into(); + let res: $name = sn.try_into().unwrap(); + assert_eq!(var, res); + } + + #[test] + fn test_arithmetic_safenum() { + let primitive: $name = ((((0 + 11) * 11) / 3) % 32) - 3; + let safe = ((((SafeNum::ZERO + $name::try_from(11u8).unwrap()) + * $name::try_from(11u8).unwrap()) + / $name::try_from(3u8).unwrap()) + % $name::try_from(32u8).unwrap()) + - $name::try_from(3u8).unwrap(); + assert_eq!($name::try_from(safe).unwrap(), primitive); + } + } + }; + } + + conversion_test!(u8); + conversion_test!(u16); + conversion_test!(u32); + conversion_test!(u64); + conversion_test!(u128); + conversion_test!(usize); + conversion_test!(i8); + conversion_test!(i16); + conversion_test!(i32); + conversion_test!(i64); + conversion_test!(i128); + conversion_test!(isize); + + macro_rules! correctness_tests { + ($name:ident, $operation:ident, $assign_operation:ident) => { + mod $operation { + use super::*; + use core::ops::$name; + + #[test] + fn test_correctness() { + let normal = 300u64; + let safe: SafeNum = normal.into(); + let rhs = 7u64; + assert_eq!( + u64::try_from(safe.$operation(rhs)).unwrap(), + normal.$operation(rhs) + ); + } + + #[test] + fn test_assign() { + let mut var: SafeNum = 2112.into(); + let rhs = 666u64; + let expect = var.$operation(rhs); + var.$assign_operation(rhs); + assert_eq!(var, expect); + } + + #[test] + fn test_assign_poison() { + let mut var = SafeNum::MIN - 1; + let expected = var - 1; + var.$assign_operation(2); + // Poison saturates and doesn't perform additional changes + assert_eq!(var, expected); + } + } + }; + } + + correctness_tests!(Add, add, add_assign); + correctness_tests!(Sub, sub, sub_assign); + correctness_tests!(Mul, mul, mul_assign); + correctness_tests!(Div, div, div_assign); + correctness_tests!(Rem, rem, rem_assign); + + #[test] + fn test_round_down() { + let x: SafeNum = 255.into(); + assert_eq!(x.round_down(32), 224.into()); + assert_eq!((x + 1).round_down(64), 256.into()); + assert_eq!(x.round_down(256), SafeNum::ZERO); + assert!(x.round_down(SafeNum::MIN).has_error()); + } + + #[test] + fn test_round_up() { + let x: SafeNum = 255.into(); + assert_eq!(x.round_up(32), 256.into()); + assert_eq!(x.round_up(51), x); + assert_eq!(SafeNum::ZERO.round_up(x), SafeNum::ZERO); + assert!(SafeNum::MAX.round_up(32).has_error()); + } +} diff --git a/gbl/tests/BUILD b/gbl/tests/BUILD index 72d1ad6..a072b39 100644 --- a/gbl/tests/BUILD +++ b/gbl/tests/BUILD @@ -21,6 +21,7 @@ test_suite( "@gbl//libfastboot:libfastboot_test", "@gbl//libfdt:libfdt_test", "@gbl//libgbl:libgbl_test", + "@gbl//libsafemath:libsafemath_test", "@gbl//libstorage:libstorage_doc_test", "@gbl//libstorage:libstorage_test", "@gbl//third_party/libzbi:libzbi_test", |