Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct EcPoint const time comparison and better comments and namings #345

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mbedtls/benches/ecp_eq_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ fn ecp_equal(a: &EcPoint, b: &EcPoint) {
}

fn ecp_equal_const_time(a: &EcPoint, b: &EcPoint) {
assert!(!a.eq_const_time(&b));
assert!(!a.eq_const_time(&b).unwrap());
}

fn criterion_benchmark(c: &mut Criterion) {
Expand Down
98 changes: 98 additions & 0 deletions mbedtls/src/bignum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,25 @@ impl Mpi {
}
}

/// Checks if an [`Mpi`] is less than the other in constant time.
///
/// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
pub fn less_than_const_time(&self, other: &Mpi) -> Result<bool> {
mpi_inner_less_than_const_time(&self.inner, &other.inner)
}

/// Compares an [`Mpi`] with the other in constant time.
///
/// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
pub fn cmp_const_time(&self, other: &Mpi) -> Result<Ordering> {
mpi_inner_cmp_const_time(&self.inner, &other.inner)
}

/// Checks equalness with the other in constant time.
pub fn eq_const_time(&self, other: &Mpi) -> Result<bool> {
mpi_inner_eq_const_time(&self.inner, &other.inner)
}

pub fn as_u32(&self) -> Result<u32> {
if self.bit_length()? > 32 {
// Not exactly correct but close enough
Expand Down Expand Up @@ -409,6 +428,35 @@ impl Mpi {
}
}

pub(super) fn mpi_inner_eq_const_time(x: &mpi, y: &mpi) -> core::prelude::v1::Result<bool, Error> {
match mpi_inner_cmp_const_time(x, y) {
Ok(order) => Ok(order == Ordering::Equal),
Err(Error::MpiBadInputData) => Ok(false),
Err(e) => Err(e),
}
}

fn mpi_inner_cmp_const_time(x: &mpi, y: &mpi) -> Result<Ordering> {
let less = mpi_inner_less_than_const_time(x, y);
let more = mpi_inner_less_than_const_time(y, x);
match (less, more) {
(Ok(true), Ok(false)) => Ok(Ordering::Less),
(Ok(false), Ok(true)) => Ok(Ordering::Greater),
(Ok(false), Ok(false)) => Ok(Ordering::Equal),
(Ok(true), Ok(true)) => unreachable!(),
(Err(e), _) => Err(e),
(Ok(_), Err(e)) => Err(e),
}
}

fn mpi_inner_less_than_const_time(x: &mpi, y: &mpi) -> Result<bool> {
let mut r = 0;
unsafe {
mpi_lt_mpi_ct(x, y, &mut r).into_result()?;
};
Ok(r == 1)
}

impl Ord for Mpi {
fn cmp(&self, other: &Mpi) -> Ordering {
let r = unsafe { mpi_cmp_mpi(&self.inner, &other.inner) };
Expand Down Expand Up @@ -709,3 +757,53 @@ impl ShrAssign<usize> for Mpi {
// mbedtls_mpi_sub_abs
// mbedtls_mpi_mod_int
// mbedtls_mpi_gcd

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_less_than_const_time() {
let mpi1 = Mpi::new(10).unwrap();
let mpi2 = Mpi::new(20).unwrap();

assert_eq!(mpi1.less_than_const_time(&mpi2), Ok(true));

assert_eq!(mpi1.less_than_const_time(&mpi1), Ok(false));

assert_eq!(mpi2.less_than_const_time(&mpi1), Ok(false));

// Check: function returns `Error::MpiBadInputData` if the allocated length of the two input Mpis is not the same.
let mpi3 = Mpi::from_binary(&[
0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd,
])
.unwrap();
assert_eq!(mpi3.less_than_const_time(&mpi3), Ok(false));
assert_eq!(mpi2.less_than_const_time(&mpi3), Err(Error::MpiBadInputData));
}

#[test]
fn test_cmp_const_time() {
let mpi1 = Mpi::new(10).unwrap();
let mpi2 = Mpi::new(20).unwrap();

assert_eq!(mpi1.cmp_const_time(&mpi2), Ok(Ordering::Less));

let mpi3 = Mpi::new(10).unwrap();
assert_eq!(mpi1.cmp_const_time(&mpi3), Ok(Ordering::Equal));

let mpi4 = Mpi::new(5).unwrap();
assert_eq!(mpi1.cmp_const_time(&mpi4), Ok(Ordering::Greater));
}

#[test]
fn test_eq_const_time() {
let mpi1 = Mpi::new(10).unwrap();
let mpi2 = Mpi::new(10).unwrap();

assert_eq!(mpi1.eq_const_time(&mpi2), Ok(true));

let mpi3 = Mpi::new(20).unwrap();
assert_eq!(mpi1.eq_const_time(&mpi3), Ok(false));
}
}
42 changes: 26 additions & 16 deletions mbedtls/src/ecp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,6 @@ impl EcPoint {
Mpi::copy(&self.inner.Y)
}

pub fn z(&self) -> Result<Mpi> {
Taowyoo marked this conversation as resolved.
Show resolved Hide resolved
Mpi::copy(&self.inner.Z)
}

pub fn is_zero(&self) -> Result<bool> {
/*
mbedtls_ecp_is_zero takes arg as non-const for no particular reason
Expand Down Expand Up @@ -373,9 +369,14 @@ Please use `mul_with_rng` instead."
///
/// This function will return an error if:
///
/// * `k` is not a valid private key, or `self` is not a valid public key.
/// * The scalar `k` is not valid as a private key, determined by mbedtls function [`mbedtls_ecp_check_privkey`].
/// * The point `self` is not valid as a public key, determined by mbedtls function [`mbedtls_ecp_check_pubkey`].
/// * Memory allocation fails.
/// * Any other kind of failure occurs during the execution of the underlying `mbedtls_ecp_mul` function.
/// * Any other kind of failure occurs during the execution of the underlying [`mbedtls_ecp_mul`] function.
///
/// [`mbedtls_ecp_check_pubkey`]: https://github.com/fortanix/rust-mbedtls/blob/main/mbedtls-sys/vendor/include/mbedtls/ecp.h#L1115-L1143
/// [`mbedtls_ecp_check_privkey`]: https://github.com/fortanix/rust-mbedtls/blob/main/mbedtls-sys/vendor/include/mbedtls/ecp.h#L1145-L1165
/// [`mbedtls_ecp_mul`]: https://github.com/fortanix/rust-mbedtls/blob/main/mbedtls-sys/vendor/include/mbedtls/ecp.h#L933-L971
pub fn mul_with_rng<F: crate::rng::Random>(&self, group: &mut EcGroup, k: &Mpi, rng: &mut F) -> Result<EcPoint> {
// Note: mbedtls_ecp_mul performs point validation itself so we skip that here

Expand Down Expand Up @@ -433,13 +434,22 @@ Please use `mul_with_rng` instead."
}
}

/// This function compares two points in const time.
pub fn eq_const_time(&self, other: &EcPoint) -> bool {
unsafe {
let x = mpi_cmp_mpi(&self.inner.X, &other.inner.X) == 0;
let y = mpi_cmp_mpi(&self.inner.Y, &other.inner.Y) == 0;
let z = mpi_cmp_mpi(&self.inner.Z, &other.inner.Z) == 0;
x & y & z
/// This function checks equalness of two points in const time.
///
/// The implementation is based on C mbedtls function [`mbedtls_ecp_point_cmp`].
/// This new implementation ensures there is no shortcut when any of `x, y ,z` fields of two points is not equal.
///
/// [`mbedtls_ecp_point_cmp`]: https://github.com/fortanix/rust-mbedtls/blob/main/mbedtls-sys/vendor/library/ecp.c#L809-L825
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One TBD here:
The mpi_cmp_mpi is also not const time.
Should I use other mpi const time function to implement this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such as this:

/**
 * \brief          Check if an MPI is less than the other in constant time.
 *
 * \param X        The left-hand MPI. This must point to an initialized MPI
 *                 with the same allocated length as Y.
 * \param Y        The right-hand MPI. This must point to an initialized MPI
 *                 with the same allocated length as X.
 * \param ret      The result of the comparison:
 *                 \c 1 if \p X is less than \p Y.
 *                 \c 0 if \p X is greater than or equal to \p Y.
 *
 * \return         0 on success.
 * \return         MBEDTLS_ERR_MPI_BAD_INPUT_DATA if the allocated length of
 *                 the two input MPIs is not the same.
 */
int mbedtls_mpi_lt_mpi_ct(const mbedtls_mpi *X, const mbedtls_mpi *Y,
                          unsigned *ret);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, was going to suggest this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Reimplemented it with calling mbedtls_mpi_lt_mpi_ct twice on x,y,z.

pub fn eq_const_time(&self, other: &EcPoint) -> Result<bool> {
let x = crate::bignum::mpi_inner_eq_const_time(&self.inner.X, &other.inner.X);
let y = crate::bignum::mpi_inner_eq_const_time(&self.inner.Y, &other.inner.Y);
let z = crate::bignum::mpi_inner_eq_const_time(&self.inner.Z, &other.inner.Z);
match (x, y, z) {
(Ok(true), Ok(true), Ok(true)) => Ok(true),
(Ok(_), Ok(_), Ok(_)) => Ok(false),
(Ok(_), Ok(_), Err(e)) => Err(e),
(Ok(_), Err(e), _) => Err(e),
(Err(e), _, _) => Err(e),
}
}

Expand Down Expand Up @@ -718,9 +728,9 @@ mod tests {
assert!(g.eq(&g).unwrap());
assert!(zero.eq(&zero).unwrap());
assert!(!g.eq(&zero).unwrap());
assert!(g.eq_const_time(&g));
assert!(zero.eq_const_time(&zero));
assert!(!g.eq_const_time(&zero));
assert!(g.eq_const_time(&g).unwrap());
assert!(zero.eq_const_time(&zero).unwrap());
assert!(!g.eq_const_time(&zero).unwrap());
}

#[test]
Expand Down
8 changes: 4 additions & 4 deletions mbedtls/src/pk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ impl Pk {
#[deprecated(
since = "0.12.3",
note = "This function does not accept an RNG so it's vulnerable to side channel attacks.
Please use `private_from_ec_components_with_rng` instead."
Please use `private_from_ec_scalar_with_rng` instead."
)]
pub fn private_from_ec_components(mut curve: EcGroup, private_key: Mpi) -> Result<Pk> {
let mut ret = Self::init();
Expand Down Expand Up @@ -378,8 +378,8 @@ Please use `private_from_ec_components_with_rng` instead."
///
/// * Fails to generate `EcPoint` from given EcGroup in `curve`.
/// * The underlying C `mbedtls_pk_setup` function fails to set up the `Pk` context.
/// * The `EcPoint::mul` function fails to generate the public key point.
pub fn private_from_ec_components_with_rng<F: Random>(mut curve: EcGroup, private_key: Mpi, rng: &mut F) -> Result<Pk> {
/// * The `EcPoint::mul_with_rng` function fails to generate the public key point.
pub fn private_from_ec_scalar_with_rng<F: Random>(mut curve: EcGroup, private_key: Mpi, rng: &mut F) -> Result<Pk> {
let mut ret = Self::init();
let curve_generator = curve.generator()?;
let public_point = curve_generator.mul_with_rng(&mut curve, &private_key, rng)?;
Expand Down Expand Up @@ -1266,7 +1266,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi

assert_eq!(pem1, pem2);

let mut key_from_components = Pk::private_from_ec_components_with_rng(
let mut key_from_components = Pk::private_from_ec_scalar_with_rng(
secp256r1.clone(),
key1.ec_private().unwrap(),
&mut crate::test_support::rand::test_rng(),
Expand Down
Loading