Skip to content

Commit

Permalink
handle the case of non-finite floats
Browse files Browse the repository at this point in the history
  • Loading branch information
radumarias committed Sep 11, 2024
1 parent 670bce2 commit 0c6efb5
Show file tree
Hide file tree
Showing 31 changed files with 462 additions and 104 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ strum_macros = "0.24"
coset = "0.3.8"
ciborium = "0.2.2"
digest = "0.10.7"
hex = "0.4.3"

[dev-dependencies]
hex = "0.4.3"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ stateDiagram

You can see the full example in [simulated_device_and_reader](tests/simulated_device_and_reader.rs) and a version that
uses `State` pattern, `Arc` and `Mutex` [simulated_device_and_reader](tests/simulated_device_and_reader_state.rs).

4 changes: 2 additions & 2 deletions macros/src/to_cbor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ fn named_fields(isomdl_path: Ident, ident: Ident, input: FieldsNamed) -> TokenSt
fn to_cbor(self) -> Value {
let map = self.to_ns_map()
.into_iter()
.map(|(k, v)| (ciborium::Value::Text(k), v.into()))
.map(|(k, v)| (ciborium::Value::Text(k), v.try_into().unwrap()))
.collect();
ciborium::Value::Map(map).into()
ciborium::Value::Map(map).try_into().unwrap()
}
}
}
Expand Down
249 changes: 229 additions & 20 deletions src/cbor.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,56 @@
use coset::{cbor, CoseError, EndOfFile};
use serde::{de, Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use std::cmp::Ordering;
use std::io::Cursor;
use std::ops::{Deref, DerefMut};

use coset::{cbor, CoseError, EndOfFile};
use serde::{de, Deserialize, Serialize};
use thiserror::Error;

/// Wraps [ciborium::Value] and implements [PartialEq], [Eq], [PartialOrd] and [Ord],
/// so it can be used in maps and sets.
///
/// [IEEE754](https://www.rfc-editor.org/rfc/rfc8949.html#IEEE754)
/// non-finite floats do not have a total ordering,
/// which means [`Ord`] cannot be correctly implemented for types that may contain them.
/// That's why we don't support such values.
///
/// Also, useful in future if we want to change the CBOR library.
#[derive(Debug, Clone)]
pub struct Value(pub ciborium::Value);
pub struct Value(pub(crate) ciborium::Value);

impl Value {
/// Create a new CBOR value.
///
/// Return an error if the value contains non-finite floats or NaN.
pub fn from(value: ciborium::Value) -> Result<Self, CborError> {
// Validate the CBOR value. If it contains non-finite floats, return an error.
if contains_non_finite_floats(&value) {
Err(CborError::NonFiniteFloats)
} else {
Ok(Value(value))
}
}

/// Unsafe version of `new`.
///
/// It will allow creating from value containing non-finite floats or NaN.
pub unsafe fn from_unsafe(value: ciborium::Value) -> Self {
Value(value)
}
}

// Helper function to check for non-finite floats
fn contains_non_finite_floats(value: &ciborium::Value) -> bool {
match value {
ciborium::Value::Float(f) => !f.is_finite(),
ciborium::Value::Array(arr) => arr.iter().any(contains_non_finite_floats),
ciborium::Value::Map(map) => map
.iter()
.any(|(k, v)| contains_non_finite_floats(k) || contains_non_finite_floats(v)),
_ => false,
}
}

#[derive(Debug, Error)]
pub enum CborError {
Expand Down Expand Up @@ -39,6 +79,48 @@ pub enum CborError {
/// Unrecognized value in neither IANA-controlled range nor private range.
#[error("unregistered non-private IANA value")]
UnregisteredIanaNonPrivateValue,
/// Value contains non-finite float (NaN or Infinity).
#[error("non finite floats")]
NonFiniteFloats,
}

impl PartialEq for CborError {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::DecodeFailed(_), Self::DecodeFailed(_)) => true,
(Self::DuplicateMapKey, Self::DuplicateMapKey) => true,
(Self::EncodeFailed, Self::EncodeFailed) => true,
(Self::ExtraneousData, Self::ExtraneousData) => true,
(Self::OutOfRangeIntegerValue, Self::OutOfRangeIntegerValue) => true,
(Self::UnexpectedItem(l_msg, l_want), Self::UnexpectedItem(r_msg, r_want)) => {
l_msg == r_msg && l_want == r_want
}
(Self::UnregisteredIanaValue, Self::UnregisteredIanaValue) => true,
(Self::UnregisteredIanaNonPrivateValue, Self::UnregisteredIanaNonPrivateValue) => true,
(Self::NonFiniteFloats, Self::NonFiniteFloats) => true,
_ => false,
}
}
}

impl Eq for CborError {}

impl Clone for CborError {
fn clone(&self) -> Self {
match self {
CborError::DecodeFailed(_) => panic!("cannot clone"),
CborError::DuplicateMapKey => CborError::DuplicateMapKey,
CborError::EncodeFailed => CborError::EncodeFailed,
CborError::ExtraneousData => CborError::ExtraneousData,
CborError::OutOfRangeIntegerValue => CborError::OutOfRangeIntegerValue,
CborError::UnexpectedItem(msg, want) => CborError::UnexpectedItem(msg, want),
CborError::UnregisteredIanaValue => CborError::UnregisteredIanaValue,
CborError::UnregisteredIanaNonPrivateValue => {
CborError::UnregisteredIanaNonPrivateValue
}
CborError::NonFiniteFloats => CborError::NonFiniteFloats,
}
}
}

impl From<CoseError> for CborError {
Expand Down Expand Up @@ -80,23 +162,23 @@ impl DerefMut for Value {
}
}

impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
impl Eq for Value {}

impl Ord for Value {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap()
}
}

impl Eq for Value {}

impl PartialOrd for Value {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.0.partial_cmp(&other.0)
}
}

impl Ord for Value {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.partial_cmp(&other.0).unwrap()
impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}

Expand Down Expand Up @@ -151,9 +233,11 @@ where
from_slice(&bytes)
}

impl From<ciborium::Value> for Value {
fn from(value: ciborium::Value) -> Self {
Self(value)
impl TryFrom<ciborium::Value> for Value {
type Error = CborError;

fn try_from(value: ciborium::Value) -> Result<Self, Self::Error> {
Value::from(value)
}
}

Expand Down Expand Up @@ -203,13 +287,18 @@ macro_rules! impl_from {
($variant:path, $for_type:ty) => {
impl From<$for_type> for Value {
fn from(v: $for_type) -> Value {
$variant(v.into()).into()
unsafe { Value::from_unsafe($variant(v.into())) }
}
}
};
}

impl_from!(ciborium::Value::Bool, bool);
impl_from!(ciborium::Value::Bytes, Vec<u8>);
impl_from!(ciborium::Value::Bytes, &[u8]);
impl_from!(ciborium::Value::Text, String);
impl_from!(ciborium::Value::Text, &str);
impl_from!(ciborium::Value::Array, Vec<ciborium::Value>);
impl_from!(ciborium::Value::Integer, i8);
impl_from!(ciborium::Value::Integer, i16);
impl_from!(ciborium::Value::Integer, i32);
Expand All @@ -222,6 +311,126 @@ impl_from!(ciborium::Value::Integer, u64);
// u128 omitted because not all numbers fit in CBOR serialization
impl_from!(ciborium::Value::Float, f32);
impl_from!(ciborium::Value::Float, f64);
impl_from!(ciborium::Value::Bytes, Vec<u8>);
impl_from!(ciborium::Value::Text, String);
impl_from!(ciborium::Value::Array, Vec<ciborium::Value>);

#[cfg(test)]
mod tests {
use crate::cbor::{CborError, Value};

#[test]
fn conversions() {
assert_eq!(
Value::from(ciborium::Value::Bool(true)),
Ok(ciborium::Value::Bool(true).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1i8.into())),
Ok(ciborium::Value::Integer(1i8.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1i16.into())),
Ok(ciborium::Value::Integer(1i16.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1i32.into())),
Ok(ciborium::Value::Integer(1i32.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1i64.into())),
Ok(ciborium::Value::Integer(1i64.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1u8.into())),
Ok(ciborium::Value::Integer(1u8.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1u16.into())),
Ok(ciborium::Value::Integer(1u16.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1u32.into())),
Ok(ciborium::Value::Integer(1u32.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Integer(1u64.into())),
Ok(ciborium::Value::Integer(1u64.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Float(1.0f32.into())),
Ok(ciborium::Value::Float(1.0f32.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Float(1.0f64.into())),
Ok(ciborium::Value::Float(1.0f64.into()).try_into().unwrap())
);
assert_eq!(
Value::from(ciborium::Value::Text("foo".to_string())),
Ok(ciborium::Value::Text("foo".to_string()).try_into().unwrap())
);
}

#[test]
fn non_finite_floats() {
assert_eq!(
Value::from(ciborium::Value::from(f32::NAN)),
Err(CborError::NonFiniteFloats)
);
assert_eq!(
Value::from(ciborium::Value::from(f32::INFINITY)),
Err(CborError::NonFiniteFloats)
);
assert_eq!(
Value::from(ciborium::Value::from(f32::NEG_INFINITY)),
Err(CborError::NonFiniteFloats)
);
assert_eq!(
Value::from(ciborium::Value::from(f64::NAN)),
Err(CborError::NonFiniteFloats)
);
assert_eq!(
Value::from(ciborium::Value::from(f64::NEG_INFINITY)),
Err(CborError::NonFiniteFloats)
);
}

#[test]
#[should_panic]
fn non_finite_floats_no_panic() {
let _ = Value::from(ciborium::Value::from(f32::NAN)).unwrap();
}

#[test]
fn non_finite_floats_unsafe() {
unsafe {
assert!(Value::from_unsafe(ciborium::Value::from(f32::NAN))
.0
.into_float()
.unwrap()
.is_nan());
assert!(Value::from_unsafe(ciborium::Value::from(f32::INFINITY))
.0
.into_float()
.unwrap()
.is_infinite());
assert!(Value::from_unsafe(ciborium::Value::from(f32::NEG_INFINITY))
.0
.into_float()
.unwrap()
.is_infinite());
assert!(Value::from_unsafe(ciborium::Value::from(f64::NAN))
.0
.into_float()
.unwrap()
.is_nan());
assert!(Value::from_unsafe(ciborium::Value::from(f64::INFINITY))
.0
.into_float()
.unwrap()
.is_infinite());
assert!(Value::from_unsafe(ciborium::Value::from(f64::NEG_INFINITY))
.0
.into_float()
.unwrap()
.is_infinite());
}
}
}
6 changes: 1 addition & 5 deletions src/cose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ where
where
S: serde::Serializer,
{
let tag = if self.tagged {
Some(T::TAG)
} else {
None
};
let tag = if self.tagged { Some(T::TAG) } else { None };

ciborium::tag::Captured(tag, SerializedAsCborValue(&self.inner)).serialize(serializer)
}
Expand Down
3 changes: 1 addition & 2 deletions src/cose/mac0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,8 @@ mod tests {
#[test]
fn roundtrip() {
let bytes = Vec::<u8>::from_hex(COSE_MAC0).unwrap();
let mut parsed: MaybeTagged<CoseMac0> =
let parsed: MaybeTagged<CoseMac0> =
cbor::from_slice(&bytes).expect("failed to parse COSE_MAC0 from bytes");
parsed.set_tagged();
let roundtripped = cbor::to_vec(&parsed).expect("failed to serialize COSE_MAC0");
assert_eq!(
bytes, roundtripped,
Expand Down
2 changes: 1 addition & 1 deletion src/cose/serialized_as_cbor_value.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use coset::AsCborValue;
use serde::{Deserialize, Serialize};

/// This is a small helper wrapper to deal with `coset`` types that don't
/// This is a small helper wrapper to deal with `coset` types that don't
/// implement `Serialize`/`Deserialize` but only `AsCborValue`.
pub struct SerializedAsCborValue<T>(pub T);

Expand Down
Loading

0 comments on commit 0c6efb5

Please sign in to comment.