From 590419a0fc7d36adc367e554084057998eeb8b8b Mon Sep 17 00:00:00 2001 From: YX Cao Date: Mon, 11 Dec 2023 15:12:28 -0800 Subject: [PATCH] Add FFI function wrappers and bump version (#336) * add more hkdf function and zeroize * bump version --- Cargo.lock | 2 +- mbedtls/Cargo.toml | 8 +- mbedtls/src/hash/mod.rs | 194 ++++++++++++++++++++++++++++++++++++---- mbedtls/src/lib.rs | 4 + 4 files changed, 189 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f04014112..3a12768a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -485,7 +485,7 @@ checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" [[package]] name = "mbedtls" -version = "0.12.0-alpha.2" +version = "0.12.0" dependencies = [ "async-stream", "bit-vec", diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 4d119a989..77451e98b 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -2,7 +2,7 @@ name = "mbedtls" # We jumped from v0.9 to v0.12 because v0.10 and v0.11 were based on mbedtls 3.X, which # we decided not to support. -version = "0.12.0-alpha.2" +version = "0.12.0" authors = ["Jethro Beekman "] build = "build.rs" edition = "2018" @@ -61,6 +61,12 @@ pin-project-lite = "0.2" [build-dependencies] cc = "1.0" +# feature 'time` is necessary under windows +[target.'cfg(target_os = "windows")'.mbedtls-platform-support] +version = "0.1" +path = "../mbedtls-platform-support" +features = ["time"] + [features] # Features are documented in the README diff --git a/mbedtls/src/hash/mod.rs b/mbedtls/src/hash/mod.rs index c9799b155..c50f09911 100644 --- a/mbedtls/src/hash/mod.rs +++ b/mbedtls/src/hash/mod.rs @@ -78,6 +78,21 @@ impl MdInfo { } } +impl Clone for Md { + fn clone(&self) -> Self { + fn copy_md(md: &Md) -> Result { + let mut ctx = Md::init(); + unsafe { + md_setup(&mut ctx.inner, md.inner.md_info, 0).into_result()?; + md_starts(&mut ctx.inner).into_result()?; + md_clone(&mut ctx.inner, &md.inner).into_result()?; + }; + Ok(ctx) + } + copy_md(self).expect("Md::copy success") + } +} + impl Md { pub fn new(md: Type) -> Result { let md: MdInfo = match md.into() { @@ -126,6 +141,7 @@ impl Md { } } +#[derive(Clone)] pub struct Hmac { ctx: Md, } @@ -178,12 +194,31 @@ impl Hmac { } } -pub struct Hkdf { - _ctx: Md, -} +/// The HMAC-based Extract-and-Expand Key Derivation Function (HKDF) is specified by RFC 5869. +#[derive(Debug)] +pub struct Hkdf; impl Hkdf { - pub fn hkdf(md: Type, salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<()> { + /// This is the HMAC-based Extract-and-Expand Key Derivation Function (HKDF). + /// + /// # Parameters + /// + /// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the hash + /// function output in bytes. + /// * `salt`: An salt value (a non-secret random value); + /// * `ikm`: The input keying material. + /// * `info`: An optional context and application specific information + /// string. This can be a zero-length string. + /// * `okm`: The output keying material. The length of the output keying material in bytes + /// must be less than or equal to 255 * `MdInfo::from(md).size()` bytes. + /// + /// # Returns + /// + /// * `()` on success. + /// * [`Error::HkdfBadInputData`] when the parameters are invalid. + /// * Any `Error::Md*` error for errors returned from the underlying + /// MD layer. + pub fn hkdf(md: Type, salt: &[u8], ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, None => return Err(Error::MdBadInputData), @@ -198,24 +233,149 @@ impl Hkdf { ikm.len(), info.as_ptr(), info.len(), - key.as_mut_ptr(), - key.len(), + okm.as_mut_ptr(), + okm.len(), ) - .into_result()?; - Ok(()) } + .into_result()?; + Ok(()) } -} -impl Clone for Md { - fn clone(&self) -> Self { - fn copy_md(md: &Md) -> Result { - let md_type = unsafe { md_get_type(md.inner.md_info) }; - let mut m = Md::new(md_type.into())?; - unsafe { md_clone(&mut m.inner, &md.inner) }.into_result()?; - Ok(m) + /// This is the HMAC-based Extract-and-Expand Key Derivation Function (HKDF). + /// + /// # Parameters + /// + /// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the hash + /// function output in bytes. + /// * `salt`: An optional salt value (a non-secret random value); + /// if the salt is not provided, a string of all zeros of + /// `MdInfo::from(md).size()` length is used as the salt. + /// * `ikm`: The input keying material. + /// * `info`: An optional context and application specific information + /// string. This can be a zero-length string. + /// * `okm`: The output keying material. The length of the output keying material in bytes + /// must be less than or equal to 255 * `MdInfo::from(md).size()` bytes. + /// + /// # Returns + /// + /// * `()` on success. + /// * [`Error::HkdfBadInputData`] when the parameters are invalid. + /// * Any `Error::Md*` error for errors returned from the underlying + /// MD layer. + pub fn hkdf_optional_salt(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { + let md: MdInfo = match md.into() { + Some(md) => md, + None => return Err(Error::MdBadInputData), + }; + + unsafe { + hkdf( + md.inner, + maybe_salt.map_or(::core::ptr::null(), |salt| salt.as_ptr()), + maybe_salt.map_or(0, |salt| salt.len()), + ikm.as_ptr(), + ikm.len(), + info.as_ptr(), + info.len(), + okm.as_mut_ptr(), + okm.len(), + ) } - copy_md(self).expect("Md::copy success") + .into_result()?; + Ok(()) + } + + /// Takes the input keying material `ikm` and extracts from it a + /// fixed-length pseudorandom key `prk`. + /// + /// # Warning + /// + /// This function should only be used if the security of it has been + /// studied and established in that particular context (eg. TLS 1.3 + /// key schedule). For standard HKDF security guarantees use + /// `hkdf` instead. + /// + /// # Parameters + /// + /// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the + /// hash function output in bytes. + /// * `salt`: An optional salt value (a non-secret random value); + /// if the salt is not provided, a string of all zeros + /// of `MdInfo::from(md).size()` length is used as the salt. + /// * `ikm`: The input keying material. + /// * `prk`: The output pseudorandom key of at least `MdInfo::from(md).size()` bytes. + /// + /// # Returns + /// + /// * `()` on success. + /// * [`Error::HkdfBadInputData`] when the parameters are invalid. + /// * Any `Error::Md*` error for errors returned from the underlying + /// MD layer. + pub fn hkdf_extract(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], prk: &mut [u8]) -> Result<()> { + let md: MdInfo = match md.into() { + Some(md) => md, + None => return Err(Error::MdBadInputData), + }; + + unsafe { + hkdf_extract( + md.inner, + maybe_salt.map_or(::core::ptr::null(), |salt| salt.as_ptr()), + maybe_salt.map_or(0, |salt| salt.len()), + ikm.as_ptr(), + ikm.len(), + prk.as_mut_ptr(), + ) + } + .into_result()?; + Ok(()) + } + + /// Expand the supplied `prk` into several additional pseudorandom keys, which is the output of the HKDF. + /// + /// # Warning + /// + /// This function should only be used if the security of it has been + /// studied and established in that particular context (eg. TLS 1.3 + /// key schedule). For standard HKDF security guarantees use + /// `hkdf` instead. + /// + /// # Parameters + /// + /// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the + /// hash function output in bytes. + /// * `prk`: A pseudorandom key of at least `MdInfo::from(md).size()` bytes. `prk` is + /// usually the output from the HKDF extract step. + /// * `info`: An optional context and application specific information + /// string. This can be a zero-length string. + /// * `okm`: The output keying material. The length of the output keying material in bytes + /// must be less than or equal to 255 * `MdInfo::from(md).size()` bytes. + /// + /// # Returns + /// + /// * `()` on success. + /// * [`Error::HkdfBadInputData`] when the parameters are invalid. + /// * Any `Error::Md*` error for errors returned from the underlying + /// MD layer. + pub fn hkdf_expand(md: Type, prk: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { + let md: MdInfo = match md.into() { + Some(md) => md, + None => return Err(Error::MdBadInputData), + }; + + unsafe { + hkdf_expand( + md.inner, + prk.as_ptr(), + prk.len(), + info.as_ptr(), + info.len(), + okm.as_mut_ptr(), + okm.len(), + ) + } + .into_result()?; + Ok(()) } } diff --git a/mbedtls/src/lib.rs b/mbedtls/src/lib.rs index dc73d8d5d..07710a99b 100644 --- a/mbedtls/src/lib.rs +++ b/mbedtls/src/lib.rs @@ -44,6 +44,10 @@ pub mod x509; #[cfg(feature = "pkcs12")] pub mod pkcs12; +pub fn zeroize(buf: &mut [u8]) { + unsafe { mbedtls_sys::platform_zeroize(buf.as_mut_ptr() as *mut mbedtls_sys::types::raw_types::c_void, buf.len()) } +} + // ============== // Utility // ==============