Skip to content

Commit

Permalink
derive Invariant for enums
Browse files Browse the repository at this point in the history
  • Loading branch information
carolynzech committed Oct 25, 2024
1 parent 693495a commit 48886c0
Showing 1 changed file with 84 additions and 22 deletions.
106 changes: 84 additions & 22 deletions library/contracts/safety/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

use proc_macro::TokenStream;
use proc_macro_error::proc_macro_error;
use quote::{quote, quote_spanned};
use quote::{format_ident, quote, quote_spanned};
use syn::{
parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
Generics, Index, ItemStruct,
parse_macro_input, parse_quote, spanned::Spanned, Data, DataEnum, DeriveInput, Fields, GenericParam,
Generics, Ident, Index, ItemStruct,
};

#[cfg(kani_host)]
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn invariant(attr: TokenStream, item: TokenStream) -> TokenStream {

/// Expands the derive macro for the Invariant trait.
/// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
/// This macro is only supported for structs.
/// This macro is only supported for structs and enums.
///
/// # Example
///
Expand All @@ -83,33 +83,67 @@ pub fn invariant(attr: TokenStream, item: TokenStream) -> TokenStream {
/// }
/// }
/// ```
/// For enums, the body of `is_safe` matches on the variant and calls `is_safe` on its fields,
/// # Example
///
/// /// ```ignore
/// #[derive(Invariant)]
/// enum MyEnum {
/// OptionOne(u32, u32),
/// OptionTwo(Square),
/// OptionThree
/// }
/// ```
///
/// expands to:
/// ```ignore
/// impl core::ub_checks::Invariant for MyEnum {
/// fn is_safe(&self) -> bool {
/// match self {
/// MyEnum::OptionOne(field1, field2) => field1.is_safe() && field2.is_safe(),
/// MyEnum::OptionTwo(field1) => field1.is_safe(),
/// MyEnum::OptionThree => true,
/// }
/// }
/// }
/// ```
/// For more information on the Invariant trait, see its documentation in core::ub_checks.
#[proc_macro_error]
#[proc_macro_derive(Invariant)]
pub fn derive_invariant(item: TokenStream) -> TokenStream {
let derive_item = parse_macro_input!(item as DeriveInput);
let item_name = &derive_item.ident;
if let Data::Struct(struct_data) = derive_item.data {
let safe_body = safe_body(&struct_data.fields);

// Add a bound `T: Invariant` to every type parameter T.
let generics = add_trait_bound_invariant(derive_item.generics);
// Generate an expression to sum up the heap size of each field.
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let safe_body = match derive_item.data {
Data::Struct(struct_data) => {
safe_body(&struct_data.fields)
},
Data::Enum(enum_data) => {
let variant_checks = variant_checks(enum_data, item_name);

let expanded = quote! {
// The generated implementation.
#[unstable(feature="invariant", issue="none")]
impl #impl_generics core::ub_checks::Invariant for #item_name #ty_generics #where_clause {
fn is_safe(&self) -> bool {
#safe_body
quote! {
match self {
#(#variant_checks),*
}
}
};
proc_macro::TokenStream::from(expanded)
} else {
panic!("Attempted to derive the Invariant trait on a non-struct type.")
}
},
Data::Union(..) => unimplemented!("Attempted to derive Invariant on a union; Invariant can only be derived for structs and enums."),
};

// Add a bound `T: Invariant` to every type parameter T.
let generics = add_trait_bound_invariant(derive_item.generics);
// Generate an expression to sum up the heap size of each field.
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let expanded = quote! {
// The generated implementation.
#[unstable(feature="invariant", issue="none")]
impl #impl_generics core::ub_checks::Invariant for #item_name #ty_generics #where_clause {
fn is_safe(&self) -> bool {
#safe_body
}
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro_error]
Expand All @@ -136,6 +170,34 @@ fn add_trait_bound_invariant(mut generics: Generics) -> Generics {
generics
}

/// Generate safety checks for each variant of an enum
fn variant_checks(enum_data: DataEnum, item_name: &Ident) -> Vec<proc_macro2::TokenStream> {
enum_data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
match &variant.fields {
Fields::Unnamed(fields) => {
let field_names: Vec<_> = fields.unnamed.iter().enumerate().map(|(i, _)| {
format_ident!("field{}", i + 1)
}).collect();

let field_checks: Vec<_> = field_names.iter().map(|field_name| {
quote! { #field_name.is_safe() }
}).collect();

quote! {
#item_name::#variant_name(#(#field_names),*) => #(#field_checks)&&*
}
},
Fields::Unit => {
quote! {
#item_name::#variant_name => true
}
},
Fields::Named(_) => unreachable!("Enums do not have named fields"),
}
}).collect()
}

/// Generate the body for the `is_safe` method.
/// For each field of the type, enforce that it is safe.
fn safe_body(fields: &Fields) -> proc_macro2::TokenStream {
Expand Down

0 comments on commit 48886c0

Please sign in to comment.