diff options
Diffstat (limited to 'src/engine/general_purpose/decode.rs')
-rw-r--r-- | src/engine/general_purpose/decode.rs | 59 |
1 files changed, 47 insertions, 12 deletions
diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index e9fd788..21a386f 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -1,5 +1,5 @@ use crate::{ - engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode}, + engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, DecodeError, PAD_BYTE, }; @@ -30,16 +30,11 @@ pub struct GeneralPurposeEstimate { impl GeneralPurposeEstimate { pub(crate) fn new(encoded_len: usize) -> Self { + // Formulas that won't overflow Self { - num_chunks: encoded_len - .checked_add(INPUT_CHUNK_LEN - 1) - .expect("Overflow when calculating number of chunks in input") - / INPUT_CHUNK_LEN, - decoded_len_estimate: encoded_len - .checked_add(3) - .expect("Overflow when calculating decoded len estimate") - / 4 - * 3, + num_chunks: encoded_len / INPUT_CHUNK_LEN + + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, + decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, } } } @@ -51,7 +46,7 @@ impl DecodeEstimate for GeneralPurposeEstimate { } /// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. -/// Returns the number of bytes written, or an error. +/// Returns the decode metadata, or an error. // We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is // inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment, // but this is fragile and the best setting changes with only minor code modifications. @@ -63,7 +58,7 @@ pub(crate) fn decode_helper( decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result<usize, DecodeError> { +) -> Result<DecodeMetadata, DecodeError> { let remainder_len = input.len() % INPUT_CHUNK_LEN; // Because the fast decode loop writes in groups of 8 bytes (unrolled to @@ -345,4 +340,44 @@ mod tests { decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); } + + #[test] + fn estimate_short_lengths() { + for (range, (num_chunks, decoded_len_estimate)) in [ + (0..=0, (0, 0)), + (1..=4, (1, 3)), + (5..=8, (1, 6)), + (9..=12, (2, 9)), + (13..=16, (2, 12)), + (17..=20, (3, 15)), + ] { + for encoded_len in range { + let estimate = GeneralPurposeEstimate::new(encoded_len); + assert_eq!(num_chunks, estimate.num_chunks); + assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); + } + } + } + + #[test] + fn estimate_via_u128_inflation() { + // cover both ends of usize + (0..1000) + .chain(usize::MAX - 1000..=usize::MAX) + .for_each(|encoded_len| { + // inflate to 128 bit type to be able to safely use the easy formulas + let len_128 = encoded_len as u128; + + let estimate = GeneralPurposeEstimate::new(encoded_len); + assert_eq!( + ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) + as usize, + estimate.num_chunks + ); + assert_eq!( + ((len_128 + 3) / 4 * 3) as usize, + estimate.decoded_len_estimate + ); + }) + } } |