From 0a96396093f3ca508aa224f8ea560d0f6d0e3f91 Mon Sep 17 00:00:00 2001 From: francisco Date: Tue, 19 Dec 2023 15:39:53 +0100 Subject: [PATCH] Add ciphertext check for RSADP --- mbedtls/src/pk/mod.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index 53fd08561..77e6bb76f 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -17,6 +17,7 @@ use crate::hash::Type as MdType; use crate::private::UnsafeFrom; use crate::rng::Random; use core::convert::TryInto; +use core::ops::Sub; use core::ptr; use byteorder::{BigEndian, ByteOrder}; @@ -691,6 +692,13 @@ impl Pk { return Err(Error::RsaOutputTooLarge); } + // Don't process outside of {2, ..., n-2} + let nm1 = self.rsa_public_modulus()?.sub(&Mpi::new(1)?)?; + let c_mpi = Mpi::from_binary(cipher)?; + if c_mpi <= Mpi::new(1).unwrap() || c_mpi >= nm1 { + return Err(Error::MpiBadInputData); + } + unsafe { rsa_private(ctx, Some(F::call), rng.data_ptr(), cipher.as_ptr(), plain.as_mut_ptr()).into_result()?; }; @@ -1023,6 +1031,7 @@ mod tests { use super::*; use crate::hash::{MdInfo, Type}; use crate::pk::Type as PkType; + use core::ops::Sub; // This is test data that must match library output *exactly* const TEST_PEM: &'static str = "-----BEGIN RSA PRIVATE KEY----- @@ -1470,4 +1479,38 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi assert!(pk.custom_public_key().is_err()); assert!(pk.custom_private_key().is_err()); } + + #[test] + fn fips_rsadp_bounds() { + let rng = &mut crate::test_support::rand::test_rng(); + let (bitlen, exp) = (2048, 0x10001); + let mut pk = Pk::generate_rsa(rng, bitlen, exp).unwrap(); + pk.set_options(Options::Rsa { + padding: RsaPadding::None, + }); + + const LEN: usize = 256; + + // Decrypting anything out of {2, n-2} should fail + let expected_err = Error::MpiBadInputData; + + let mut pt = [0x00; LEN]; + + let _0 = Mpi::new(0).unwrap(); + let _1 = Mpi::new(1).unwrap(); + let _2 = Mpi::new(2).unwrap(); + let n = pk.rsa_public_modulus().unwrap(); + let nm1 = pk.rsa_public_modulus().unwrap().sub(&Mpi::new(1).unwrap()).unwrap(); + let nm2 = pk.rsa_public_modulus().unwrap().sub(&Mpi::new(2).unwrap()).unwrap(); + for c in [_0, _1, nm1, n] { + let ct = c.to_binary_padded(LEN).unwrap(); + let l = pk.decrypt(&ct, &mut pt, rng); + assert_eq!(l.unwrap_err(), expected_err); + } + for c in [_2, nm2] { + let ct = c.to_binary_padded(LEN).unwrap(); + let l = pk.decrypt(&ct, &mut pt, rng); + assert_eq!(l.unwrap(), LEN); + } + } }