Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
/// This function is used to select a random entry from an array
/// when the bytes are a random seed.
///
/// `bytes` is read as a single big-endian unsigned integer of arbitrary
/// length, and the remainder of that integer divided by `n` is returned.
/// An empty `bytes` slice is treated as the value 0.
///
/// Returns 0 when `n == 0` rather than panicking on division by zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    if n == 0 {
        return 0;
    }

    // Horner's rule in base 256: remainder = (remainder * 256 + byte) mod n,
    // one byte at a time. The running remainder is always < n <= u64::MAX, so
    // shifting it left by 8 can need up to 72 bits; doing the step in u128
    // means no bits are lost for any modulus, including ones above 2^56.
    let modulus = u128::from(n);
    let remainder = bytes.iter().fold(0u128, |acc, &byte| {
        // The low 8 bits of `acc << 8` are zero, so `|` is the same as `+`.
        ((acc << 8) | u128::from(byte)) % modulus
    });

    // The remainder is strictly less than `n`, so it always fits in a u64.
    u64::try_from(remainder).expect("remainder is smaller than the u64 modulus")
}
Expand Down Expand Up @@ -190,16 +195,16 @@ mod tests {

// Test case 5: whitespace
let h = "01 02 03";
assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
assert_eq!(from_hex_formatted(h).unwrap(), &[0x01, 0x02, 0x03]);

// Test case 6: 0x prefix
let h = "0x010203";
assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
assert_eq!(from_hex_formatted(h).unwrap(), &[0x01, 0x02, 0x03]);

// Test case 7: 0x prefix + different whitespace chars
let h = " \n\n0x\r\n01
02\t03\n";
assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
assert_eq!(from_hex_formatted(h).unwrap(), &[0x01, 0x02, 0x03]);
}

#[test]
Expand Down Expand Up @@ -320,8 +325,11 @@ mod tests {

// Test case 2: multiple bytes
assert_eq!(modulo(&[0x01, 0x02, 0x03], 10), 1);

// Test case 3: n=0
assert_eq!(modulo(&[0x01, 0x02, 0x03], 0), 0);

// Test case 3: check equivalence with BigUint
// Test case 4: check equivalence with BigUint
let n = 11u64;
for i in 0..100 {
let mut rng = StdRng::seed_from_u64(i);
Expand All @@ -330,6 +338,13 @@ mod tests {
let utils_modulo = modulo(&bytes, n);
assert_eq!(big_modulo, BigUint::from(utils_modulo));
}

// Test case 5: large modulus to check overflow handling
let large_modulus = u64::MAX - 1;
let bytes = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01];
let big_modulo = BigUint::from_bytes_be(&bytes) % large_modulus;
let utils_modulo = modulo(&bytes, large_modulus);
assert_eq!(big_modulo, BigUint::from(utils_modulo));
}

#[test]
Expand Down
Loading