use crate::{ engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, DecodeError, PAD_BYTE, }; // decode logic operates on chunks of 8 input bytes without padding const INPUT_CHUNK_LEN: usize = 8; const DECODED_CHUNK_LEN: usize = 6; // we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last // 2 bytes of any output u64 should not be counted as written to (but must be available in a // slice). const DECODED_CHUNK_SUFFIX: usize = 2; // how many u64's of input to handle at a time const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; // includes the trailing 2 bytes for the final u64 write const DECODED_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; #[doc(hidden)] pub struct GeneralPurposeEstimate { /// Total number of decode chunks, including a possibly partial last chunk num_chunks: usize, decoded_len_estimate: usize, } impl GeneralPurposeEstimate { pub(crate) fn new(encoded_len: usize) -> Self { // Formulas that won't overflow Self { num_chunks: encoded_len / INPUT_CHUNK_LEN + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, } } } impl DecodeEstimate for GeneralPurposeEstimate { fn decoded_len_estimate(&self) -> usize { self.decoded_len_estimate } } /// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. /// Returns the decode metadata, or an error. // We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is // inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment, // but this is fragile and the best setting changes with only minor code modifications. #[inline] pub(crate) fn decode_helper( input: &[u8], estimate: GeneralPurposeEstimate, output: &mut [u8], decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { let remainder_len = input.len() % INPUT_CHUNK_LEN; // Because the fast decode loop writes in groups of 8 bytes (unrolled to // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of // which only 6 are valid data), we need to be sure that we stop using the fast decode loop // soon enough that there will always be 2 more bytes of valid data written after that loop. let trailing_bytes_to_skip = match remainder_len { // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, // and the fast decode logic cannot handle padding 0 => INPUT_CHUNK_LEN, // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte 1 | 5 => { // trailing whitespace is so common that it's worth it to check the last byte to // possibly return a better error message if let Some(b) = input.last() { if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { return Err(DecodeError::InvalidByte(input.len() - 1, *b)); } } return Err(DecodeError::InvalidLength); } // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes // written by the fast decode loop. So, we have to ignore both these 2 bytes and the // previous chunk. 2 => INPUT_CHUNK_LEN + 2, // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail // with an error, not panic from going past the bounds of the output slice, so we let it // use stage 3 + 4. 3 => INPUT_CHUNK_LEN + 3, // This can also decode to one output byte because it may be 2 input chars + 2 padding // chars, which would decode to 1 byte. 4 => INPUT_CHUNK_LEN + 4, // Everything else is a legal decode len (given that we don't require padding), and will // decode to at least 2 bytes of output. _ => remainder_len, }; // rounded up to include partial chunks let mut remaining_chunks = estimate.num_chunks; let mut input_index = 0; let mut output_index = 0; { let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); // Fast loop, stage 1 // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { while input_index <= max_start_index { let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; decode_chunk( &input_slice[0..], input_index, decode_table, &mut output_slice[0..], )?; decode_chunk( &input_slice[8..], input_index + 8, decode_table, &mut output_slice[6..], )?; decode_chunk( &input_slice[16..], input_index + 16, decode_table, &mut output_slice[12..], )?; decode_chunk( &input_slice[24..], input_index + 24, decode_table, &mut output_slice[18..], )?; input_index += INPUT_BLOCK_LEN; output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; } } // Fast loop, stage 2 (aka still pretty fast loop) // 8 bytes at a time for whatever we didn't do in stage 1. if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { while input_index < max_start_index { decode_chunk( &input[input_index..(input_index + INPUT_CHUNK_LEN)], input_index, decode_table, &mut output [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], )?; output_index += DECODED_CHUNK_LEN; input_index += INPUT_CHUNK_LEN; remaining_chunks -= 1; } } } // Stage 3 // If input length was such that a chunk had to be deferred until after the fast loop // because decoding it would have produced 2 trailing bytes that wouldn't then be // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 // trailing bytes. // However, we still need to avoid the last chunk (partial or complete) because it could // have padding, so we always do 1 fewer to avoid the last chunk. for _ in 1..remaining_chunks { decode_chunk_precise( &input[input_index..], input_index, decode_table, &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], )?; input_index += INPUT_CHUNK_LEN; output_index += DECODED_CHUNK_LEN; } // always have one more (possibly partial) block of 8 input debug_assert!(input.len() - input_index > 1 || input.is_empty()); debug_assert!(input.len() - input_index <= 8); super::decode_suffix::decode_suffix( input, input_index, output, output_index, decode_table, decode_allow_trailing_bits, padding_mode, ) } /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the /// first 6 of those contain meaningful data. /// /// `input` is the bytes to decode, of which the first 8 bytes will be processed. /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors /// accurately) /// `decode_table` is the lookup table for the particular base64 alphabet. /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded /// data. // yes, really inline (worth 30-50% speedup) #[inline(always)] fn decode_chunk( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { let morsel = decode_table[input[0] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); } let mut accum = (morsel as u64) << 58; let morsel = decode_table[input[1] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 1, input[1], )); } accum |= (morsel as u64) << 52; let morsel = decode_table[input[2] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 2, input[2], )); } accum |= (morsel as u64) << 46; let morsel = decode_table[input[3] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 3, input[3], )); } accum |= (morsel as u64) << 40; let morsel = decode_table[input[4] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 4, input[4], )); } accum |= (morsel as u64) << 34; let morsel = decode_table[input[5] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 5, input[5], )); } accum |= (morsel as u64) << 28; let morsel = decode_table[input[6] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 6, input[6], )); } accum |= (morsel as u64) << 22; let morsel = decode_table[input[7] as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 7, input[7], )); } accum |= (morsel as u64) << 16; write_u64(output, accum); Ok(()) } /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 /// trailing garbage bytes. #[inline] fn decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { let mut tmp_buf = [0_u8; 8]; decode_chunk( input, index_at_start_of_input, decode_table, &mut tmp_buf[..], )?; output[0..6].copy_from_slice(&tmp_buf[0..6]); Ok(()) } #[inline] fn write_u64(output: &mut [u8], value: u64) { output[..8].copy_from_slice(&value.to_be_bytes()); } #[cfg(test)] mod tests { use super::*; use crate::engine::general_purpose::STANDARD; #[test] fn decode_chunk_precise_writes_only_6_bytes() { let input = b"Zm9vYmFy"; // "foobar" let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); } #[test] fn decode_chunk_writes_8_bytes() { let input = b"Zm9vYmFy"; // "foobar" let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); } #[test] fn estimate_short_lengths() { for (range, (num_chunks, decoded_len_estimate)) in [ (0..=0, (0, 0)), (1..=4, (1, 3)), (5..=8, (1, 6)), (9..=12, (2, 9)), (13..=16, (2, 12)), (17..=20, (3, 15)), ] { for encoded_len in range { let estimate = GeneralPurposeEstimate::new(encoded_len); assert_eq!(num_chunks, estimate.num_chunks); assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); } } } #[test] fn estimate_via_u128_inflation() { // cover both ends of usize (0..1000) .chain(usize::MAX - 1000..=usize::MAX) .for_each(|encoded_len| { // inflate to 128 bit type to be able to safely use the easy formulas let len_128 = encoded_len as u128; let estimate = GeneralPurposeEstimate::new(encoded_len); assert_eq!( ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) as usize, estimate.num_chunks ); assert_eq!( ((len_128 + 3) / 4 * 3) as usize, estimate.decoded_len_estimate ); }) } }