From 39b42b8fa364a1456e312dfb5005e3daca28d83d Mon Sep 17 00:00:00 2001 From: zachs18 <8355914+zachs18@users.noreply.github.com> Date: Wed, 26 Jul 2023 14:50:58 -0500 Subject: [PATCH] Zeroable derive custom bounds (#196) * Add space because rust-analyzer thinks it's a weird literal otherwise. * Add custom bounds for Zeroable. * Cleanup custom bounds code. * Only parse explicit bounds for Zeroable. * Qualify syn types. * If no explicit bounds are given, apply the default bounds. * Add perfect derive semantics to `#[zeroable(bound = "...")]`. --- derive/src/lib.rs | 196 ++++++++++++++++++++++++++++++++++++++++--- derive/src/traits.rs | 9 +- 2 files changed, 192 insertions(+), 13 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 448e503..23ddf5e 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -77,7 +77,6 @@ pub fn derive_anybitpattern( /// /// ```rust /// # use bytemuck_derive::{Zeroable}; -/// /// #[derive(Copy, Clone, Zeroable)] /// #[repr(C)] /// struct Test { @@ -85,7 +84,65 @@ pub fn derive_anybitpattern( /// b: u16, /// } /// ``` -#[proc_macro_derive(Zeroable)] +/// +/// # Custom bounds +/// +/// Custom bounds for the derived `Zeroable` impl can be given using the +/// `#[zeroable(bound = "")]` helper attribute. +/// +/// Using this attribute additionally opts-in to "perfect derive" semantics, +/// where instead of adding bounds for each generic type parameter, bounds are +/// added for each field's type. +/// +/// ## Examples +/// +/// ```rust +/// # use bytemuck::Zeroable; +/// # use std::marker::PhantomData; +/// #[derive(Clone, Zeroable)] +/// #[zeroable(bound = "")] +/// struct AlwaysZeroable { +/// a: PhantomData, +/// } +/// +/// AlwaysZeroable::::zeroed(); +/// ``` +/// +/// ```rust,compile_fail +/// # use bytemuck::Zeroable; +/// # use std::marker::PhantomData; +/// #[derive(Clone, Zeroable)] +/// #[zeroable(bound = "T: Copy")] +/// struct ZeroableWhenTIsCopy { +/// a: PhantomData, +/// } +/// +/// ZeroableWhenTIsCopy::::zeroed(); +/// ``` +/// +/// The restriction that all fields must be Zeroable is still applied, and this +/// is enforced using the mentioned "perfect derive" semantics. +/// +/// ```rust +/// # use bytemuck::Zeroable; +/// #[derive(Clone, Zeroable)] +/// #[zeroable(bound = "")] +/// struct ZeroableWhenTIsZeroable { +/// a: T, +/// } +/// ZeroableWhenTIsZeroable::::zeroed(); +/// ``` +/// +/// ```rust,compile_fail +/// # use bytemuck::Zeroable; +/// # #[derive(Clone, Zeroable)] +/// # #[zeroable(bound = "")] +/// # struct ZeroableWhenTIsZeroable { +/// # a: T, +/// # } +/// ZeroableWhenTIsZeroable::::zeroed(); +/// ``` +#[proc_macro_derive(Zeroable, attributes(zeroable))] pub fn derive_zeroable( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { @@ -317,12 +374,127 @@ fn derive_marker_trait(input: DeriveInput) -> TokenStream { .unwrap_or_else(|err| err.into_compile_error()) } +/// Find `#[name(key = "value")]` helper attributes on the struct, and return +/// their `"value"`s parsed with `parser`. +/// +/// Returns an error if any attributes with the given `name` do not match the +/// expected format. Returns `Ok([])` if no attributes with `name` are found. +fn find_and_parse_helper_attributes( + attributes: &[syn::Attribute], name: &str, key: &str, parser: P, + example_value: &str, invalid_value_msg: &str, +) -> Result> { + let invalid_format_msg = + format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",); + let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta { + // If a `Path` matches our `name`, return an error, else ignore it. + // e.g. `#[zeroable]` + syn::Meta::Path(path) => path + .is_ident(name) + .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))), + // If a `NameValue` matches our `name`, return an error, else ignore it. + // e.g. `#[zeroable = "hello"]` + syn::Meta::NameValue(namevalue) => { + namevalue.path.is_ident(name).then(|| { + Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg)) + }) + } + // If a `List` matches our `name`, match its contents to our format, else + // ignore it. If its contents match our format, return the value, else + // return an error. + syn::Meta::List(list) => list.path.is_ident(name).then(|| { + let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone()) + .map_err(|_| { + syn::Error::new_spanned(&list.tokens, &invalid_format_msg) + })?; + if namevalue.path.is_ident(key) { + match namevalue.value { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(strlit), .. + }) => Ok(strlit), + _ => { + Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg)) + } + } + } else { + Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg)) + } + }), + }); + // Parse each value found with the given parser, and return them if no errors + // occur. + values_to_check + .map(|lit| { + let lit = lit?; + lit.parse_with(parser).map_err(|err| { + syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}")) + }) + }) + .collect() +} + fn derive_marker_trait_inner( mut input: DeriveInput, ) -> Result { - // Enforce Pod on all generic fields. let trait_ = Trait::ident(&input)?; - add_trait_marker(&mut input.generics, &trait_); + // If this trait allows explicit bounds, and any explicit bounds were given, + // then use those explicit bounds. Else, apply the default bounds (bound + // each generic type on this trait). + if let Some(name) = Trait::explicit_bounds_attribute_name() { + // See if any explicit bounds were given in attributes. + let explicit_bounds = find_and_parse_helper_attributes( + &input.attrs, + name, + "bound", + >::parse_terminated, + "Type: Trait", + "invalid where predicate", + )?; + + if !explicit_bounds.is_empty() { + // Explicit bounds were given. + // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add + // bounds for each field's type). + let explicit_bounds = explicit_bounds + .into_iter() + .flatten() + .collect::>(); + + let predicates = &mut input.generics.make_where_clause().predicates; + + predicates.extend(explicit_bounds); + + let fields = match &input.data { + syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(), + syn::Data::Union(_) => { + return Err(syn::Error::new_spanned( + trait_, + &"perfect derive is not supported for unions", + )); + } + syn::Data::Enum(_) => { + return Err(syn::Error::new_spanned( + trait_, + &"perfect derive is not supported for enums", + )); + } + }; + + for field in fields { + let ty = field.ty; + predicates.push(syn::parse_quote!( + #ty: #trait_ + )); + } + } else { + // No explicit bounds were given. + // Enforce trait bound on all type generics. + add_trait_marker(&mut input.generics, &trait_); + } + } else { + // This trait does not allow explicit bounds. + // Enforce trait bound on all type generics. + add_trait_marker(&mut input.generics, &trait_); + } let name = &input.ident; @@ -339,11 +511,8 @@ fn derive_marker_trait_inner( quote!() }; - let where_clause = if Trait::requires_where_clause() { - where_clause - } else { - None - }; + let where_clause = + if Trait::requires_where_clause() { where_clause } else { None }; Ok(quote! { #asserts @@ -364,9 +533,12 @@ fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) { let type_params = generics .type_params() .map(|param| ¶m.ident) - .map(|param| syn::parse_quote!( - #param: #trait_name - )).collect::>(); + .map(|param| { + syn::parse_quote!( + #param: #trait_name + ) + }) + .collect::>(); generics.make_where_clause().predicates.extend(type_params); } diff --git a/derive/src/traits.rs b/derive/src/traits.rs index ec6afd9..da51275 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -35,6 +35,9 @@ pub trait Derivable { fn requires_where_clause() -> bool { true } + fn explicit_bounds_attribute_name() -> Option<&'static str> { + None + } } pub struct Pod; @@ -126,6 +129,10 @@ impl Derivable for Zeroable { Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"), } } + + fn explicit_bounds_attribute_name() -> Option<&'static str> { + Some("zeroable") + } } pub struct NoUninit; @@ -532,7 +539,7 @@ fn generate_assert_no_padding(input: &DeriveInput) -> Result { let size_rest = quote_spanned!(span => #( + ::core::mem::size_of::<#field_types>() )*); - quote_spanned!(span => #size_first#size_rest) + quote_spanned!(span => #size_first #size_rest) } else { quote_spanned!(span => 0) };