diff --git a/src/libextra/crypto/cryptoutil.rs b/src/libextra/crypto/cryptoutil.rs index b89f77ec5c1..9fd02baf541 100644 --- a/src/libextra/crypto/cryptoutil.rs +++ b/src/libextra/crypto/cryptoutil.rs @@ -8,7 +8,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::num::One; +use std::num::{One, Zero, CheckedAdd}; use std::vec::bytes::{MutableByteVector, copy_memory}; @@ -97,50 +97,73 @@ pub fn read_u32v_le(dst: &mut[u32], input: &[u8]) { } -/// Returns true if adding the two parameters will result in integer overflow -pub fn will_add_overflow(x: T, y: T) -> bool { - // This doesn't handle negative values! Don't copy this code elsewhere without considering if - // negative values are important to you! - let max: T = Bounded::max_value(); - return x > max - y; +trait ToBits { + /// Convert the value in bytes to the number of bits, a tuple where the 1st item is the + /// high-order value and the 2nd item is the low order value. + fn to_bits(self) -> (Self, Self); } -/// Shifts the second parameter and then adds it to the first. fails!() if there would be unsigned -/// integer overflow. -pub fn shift_add_check_overflow(x: T, mut y: T, shift: T) -> T { - if y.leading_zeros() < shift { - fail!("Could not add values - integer overflow."); +impl ToBits for u64 { + fn to_bits(self) -> (u64, u64) { + return (self >> 61, self << 3); } - y = y << shift; - - if will_add_overflow(x.clone(), y.clone()) { - fail!("Could not add values - integer overflow."); - } - - return x + y; } -/// Shifts the second parameter and then adds it to the first, which is a tuple where the first -/// element is the high order value. fails!() if there would be unsigned integer overflow. -pub fn shift_add_check_overflow_tuple - - (x: (T, T), mut y: T, shift: T) -> (T, T) { - if y.leading_zeros() < shift { - fail!("Could not add values - integer overflow."); - } - y = y << shift; +/// Adds the specified number of bytes to the bit count. fail!() if this would cause numeric +/// overflow. +pub fn add_bytes_to_bits(bits: T, bytes: T) -> T { + let (new_high_bits, new_low_bits) = bytes.to_bits(); - match x { - (hi, low) => { - let one: T = One::one(); - if will_add_overflow(low.clone(), y.clone()) { - if will_add_overflow(hi.clone(), one.clone()) { - fail!("Could not add values - integer overflow."); - } else { - return (hi + one, low + y); - } + if new_high_bits > Zero::zero() { + fail!("Numeric overflow occured.") + } + + match bits.checked_add(&new_low_bits) { + Some(x) => return x, + None => fail!("Numeric overflow occured.") + } +} + +/// Adds the specified number of bytes to the bit count, which is a tuple where the first element is +/// the high order value. fail!() if this would cause numeric overflow. +pub fn add_bytes_to_bits_tuple + + (bits: (T, T), bytes: T) -> (T, T) { + let (new_high_bits, new_low_bits) = bytes.to_bits(); + let (hi, low) = bits; + + // Add the low order value - if there is no overflow, then add the high order values + // If the addition of the low order values causes overflow, add one to the high order values + // before adding them. + match low.checked_add(&new_low_bits) { + Some(x) => { + if new_high_bits == Zero::zero() { + // This is the fast path - every other alternative will rarely occur in practice + // considering how large an input would need to be for those paths to be used. + return (hi, x); } else { - return (hi, low + y); + match hi.checked_add(&new_high_bits) { + Some(y) => return (y, x), + None => fail!("Numeric overflow occured.") + } + } + }, + None => { + let one: T = One::one(); + let z = match new_high_bits.checked_add(&one) { + Some(w) => w, + None => fail!("Numeric overflow occured.") + }; + match hi.checked_add(&z) { + // This re-executes the addition that was already performed earlier when overflow + // occured, this time allowing the overflow to happen. Technically, this could be + // avoided by using the checked add intrinsic directly, but that involves using + // unsafe code and is not really worthwhile considering how infrequently code will + // run in practice. This is the reason that this function requires that the type T + // be Unsigned - overflow is not defined for Signed types. This function could be + // implemented for signed types as well if that were needed. + Some(y) => return (y, low + new_low_bits), + None => fail!("Numeric overflow occured.") } } } diff --git a/src/libextra/crypto/sha1.rs b/src/libextra/crypto/sha1.rs index 8ee9006f613..4d4d47feee8 100644 --- a/src/libextra/crypto/sha1.rs +++ b/src/libextra/crypto/sha1.rs @@ -23,7 +23,7 @@ */ -use cryptoutil::{write_u32_be, read_u32v_be, shift_add_check_overflow, FixedBuffer, FixedBuffer64, +use cryptoutil::{write_u32_be, read_u32v_be, add_bytes_to_bits, FixedBuffer, FixedBuffer64, StandardPadding}; use digest::Digest; @@ -52,7 +52,7 @@ pub struct Sha1 { fn add_input(st: &mut Sha1, msg: &[u8]) { assert!((!st.computed)); // Assumes that msg.len() can be converted to u64 without overflow - st.length_bits = shift_add_check_overflow(st.length_bits, msg.len() as u64, 3); + st.length_bits = add_bytes_to_bits(st.length_bits, msg.len() as u64); st.buffer.input(msg, |d: &[u8]| { process_msg_block(d, &mut st.h); }); } diff --git a/src/libextra/crypto/sha2.rs b/src/libextra/crypto/sha2.rs index 47535d5103a..96f3e13eb22 100644 --- a/src/libextra/crypto/sha2.rs +++ b/src/libextra/crypto/sha2.rs @@ -10,8 +10,8 @@ use std::uint; -use cryptoutil::{write_u64_be, write_u32_be, read_u64v_be, read_u32v_be, shift_add_check_overflow, - shift_add_check_overflow_tuple, FixedBuffer, FixedBuffer128, FixedBuffer64, StandardPadding}; +use cryptoutil::{write_u64_be, write_u32_be, read_u64v_be, read_u32v_be, add_bytes_to_bits, + add_bytes_to_bits_tuple, FixedBuffer, FixedBuffer128, FixedBuffer64, StandardPadding}; use digest::Digest; @@ -210,7 +210,7 @@ impl Engine512 { fn input(&mut self, input: &[u8]) { assert!(!self.finished) // Assumes that input.len() can be converted to u64 without overflow - self.length_bits = shift_add_check_overflow_tuple(self.length_bits, input.len() as u64, 3); + self.length_bits = add_bytes_to_bits_tuple(self.length_bits, input.len() as u64); self.buffer.input(input, |input: &[u8]| { self.state.process_block(input) }); } @@ -602,7 +602,7 @@ impl Engine256 { fn input(&mut self, input: &[u8]) { assert!(!self.finished) // Assumes that input.len() can be converted to u64 without overflow - self.length_bits = shift_add_check_overflow(self.length_bits, input.len() as u64, 3); + self.length_bits = add_bytes_to_bits(self.length_bits, input.len() as u64); self.buffer.input(input, |input: &[u8]| { self.state.process_block(input) }); }