Skip to content

Commit

Permalink
Factor our code that checks for invalid trailing bits into functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DaGenix committed Nov 12, 2021
1 parent 3bcd4a4 commit a8d2666
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
10 changes: 6 additions & 4 deletions src/stateful_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ pub enum ProvideQuintetResult {
HaveOctets(HaveOctets),
}

pub fn quintet_has_valid_trailing_bits(last_quintet_bits: u8, quintet: u8) -> bool {
let trailing_bits_mask = 0x1fu8.checked_shr(last_quintet_bits as u32).unwrap_or(0);
quintet & trailing_bits_mask == 0
}

impl NeedQuintets {
pub fn new(last_quintet_bits: u8) -> Result<NeedQuintets, ZBase32Error> {
assert!(last_quintet_bits != 0 && last_quintet_bits <= 5);
Expand All @@ -31,10 +36,7 @@ impl NeedQuintets {
}

if last_quintet {
let trailing_bits_mask = 0x1fu8
.checked_shr(self.last_quintet_bits as u32)
.unwrap_or(0);
if quintet & trailing_bits_mask != 0 {
if !quintet_has_valid_trailing_bits(self.last_quintet_bits, quintet) {
return Err(zbase32_error(ZBase32ErrorInfo::TrailingNonZeroBits));
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/stateful_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ pub enum ProvideOctetResult {
HaveQuintets(HaveQuintets),
}

pub fn octet_has_valid_trailing_bits(last_octet_bits: u8, octet: u8) -> bool {
let trailing_bits_mask = 0xffu8.checked_shr(last_octet_bits as u32).unwrap_or(0);
octet & trailing_bits_mask == 0
}

impl NeedOctets {
pub fn new(last_octet_bits: u8) -> Result<NeedOctets, ZBase32Error> {
assert!(last_octet_bits != 0 && last_octet_bits <= 8);
Expand All @@ -27,8 +32,7 @@ impl NeedOctets {
last_octet: bool,
) -> Result<ProvideOctetResult, ZBase32Error> {
if last_octet {
let trailing_bits_mask = 0xffu8.checked_shr(self.last_octet_bits as u32).unwrap_or(0);
if octet & trailing_bits_mask != 0 {
if !octet_has_valid_trailing_bits(self.last_octet_bits, octet) {
return Err(zbase32_error(ZBase32ErrorInfo::TrailingNonZeroBits));
}
}
Expand Down

0 comments on commit a8d2666

Please sign in to comment.