mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-21 14:22:26 +00:00
Zeroable derive custom bounds (#196)
* Add space because rust-analyzer thinks it's a weird literal otherwise. * Add custom bounds for Zeroable. * Cleanup custom bounds code. * Only parse explicit bounds for Zeroable. * Qualify syn types. * If no explicit bounds are given, apply the default bounds. * Add perfect derive semantics to `#[zeroable(bound = "...")]`.
This commit is contained in:
parent
3b81c85c60
commit
39b42b8fa3
@ -77,7 +77,6 @@ pub fn derive_anybitpattern(
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck_derive::{Zeroable};
|
||||
///
|
||||
/// #[derive(Copy, Clone, Zeroable)]
|
||||
/// #[repr(C)]
|
||||
/// struct Test {
|
||||
@ -85,7 +84,65 @@ pub fn derive_anybitpattern(
|
||||
/// b: u16,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(Zeroable)]
|
||||
///
|
||||
/// # Custom bounds
|
||||
///
|
||||
/// Custom bounds for the derived `Zeroable` impl can be given using the
|
||||
/// `#[zeroable(bound = "")]` helper attribute.
|
||||
///
|
||||
/// Using this attribute additionally opts-in to "perfect derive" semantics,
|
||||
/// where instead of adding bounds for each generic type parameter, bounds are
|
||||
/// added for each field's type.
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck::Zeroable;
|
||||
/// # use std::marker::PhantomData;
|
||||
/// #[derive(Clone, Zeroable)]
|
||||
/// #[zeroable(bound = "")]
|
||||
/// struct AlwaysZeroable<T> {
|
||||
/// a: PhantomData<T>,
|
||||
/// }
|
||||
///
|
||||
/// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
|
||||
/// ```
|
||||
///
|
||||
/// ```rust,compile_fail
|
||||
/// # use bytemuck::Zeroable;
|
||||
/// # use std::marker::PhantomData;
|
||||
/// #[derive(Clone, Zeroable)]
|
||||
/// #[zeroable(bound = "T: Copy")]
|
||||
/// struct ZeroableWhenTIsCopy<T> {
|
||||
/// a: PhantomData<T>,
|
||||
/// }
|
||||
///
|
||||
/// ZeroableWhenTIsCopy::<String>::zeroed();
|
||||
/// ```
|
||||
///
|
||||
/// The restriction that all fields must be Zeroable is still applied, and this
|
||||
/// is enforced using the mentioned "perfect derive" semantics.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck::Zeroable;
|
||||
/// #[derive(Clone, Zeroable)]
|
||||
/// #[zeroable(bound = "")]
|
||||
/// struct ZeroableWhenTIsZeroable<T> {
|
||||
/// a: T,
|
||||
/// }
|
||||
/// ZeroableWhenTIsZeroable::<u32>::zeroed();
|
||||
/// ```
|
||||
///
|
||||
/// ```rust,compile_fail
|
||||
/// # use bytemuck::Zeroable;
|
||||
/// # #[derive(Clone, Zeroable)]
|
||||
/// # #[zeroable(bound = "")]
|
||||
/// # struct ZeroableWhenTIsZeroable<T> {
|
||||
/// # a: T,
|
||||
/// # }
|
||||
/// ZeroableWhenTIsZeroable::<String>::zeroed();
|
||||
/// ```
|
||||
#[proc_macro_derive(Zeroable, attributes(zeroable))]
|
||||
pub fn derive_zeroable(
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
@ -317,12 +374,127 @@ fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
||||
.unwrap_or_else(|err| err.into_compile_error())
|
||||
}
|
||||
|
||||
/// Find `#[name(key = "value")]` helper attributes on the struct, and return
|
||||
/// their `"value"`s parsed with `parser`.
|
||||
///
|
||||
/// Returns an error if any attributes with the given `name` do not match the
|
||||
/// expected format. Returns `Ok([])` if no attributes with `name` are found.
|
||||
fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
|
||||
attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
|
||||
example_value: &str, invalid_value_msg: &str,
|
||||
) -> Result<Vec<P::Output>> {
|
||||
let invalid_format_msg =
|
||||
format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
|
||||
let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
|
||||
// If a `Path` matches our `name`, return an error, else ignore it.
|
||||
// e.g. `#[zeroable]`
|
||||
syn::Meta::Path(path) => path
|
||||
.is_ident(name)
|
||||
.then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
|
||||
// If a `NameValue` matches our `name`, return an error, else ignore it.
|
||||
// e.g. `#[zeroable = "hello"]`
|
||||
syn::Meta::NameValue(namevalue) => {
|
||||
namevalue.path.is_ident(name).then(|| {
|
||||
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
|
||||
})
|
||||
}
|
||||
// If a `List` matches our `name`, match its contents to our format, else
|
||||
// ignore it. If its contents match our format, return the value, else
|
||||
// return an error.
|
||||
syn::Meta::List(list) => list.path.is_ident(name).then(|| {
|
||||
let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
|
||||
.map_err(|_| {
|
||||
syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
|
||||
})?;
|
||||
if namevalue.path.is_ident(key) {
|
||||
match namevalue.value {
|
||||
syn::Expr::Lit(syn::ExprLit {
|
||||
lit: syn::Lit::Str(strlit), ..
|
||||
}) => Ok(strlit),
|
||||
_ => {
|
||||
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
|
||||
}
|
||||
}),
|
||||
});
|
||||
// Parse each value found with the given parser, and return them if no errors
|
||||
// occur.
|
||||
values_to_check
|
||||
.map(|lit| {
|
||||
let lit = lit?;
|
||||
lit.parse_with(parser).map_err(|err| {
|
||||
syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn derive_marker_trait_inner<Trait: Derivable>(
|
||||
mut input: DeriveInput,
|
||||
) -> Result<TokenStream> {
|
||||
// Enforce Pod on all generic fields.
|
||||
let trait_ = Trait::ident(&input)?;
|
||||
add_trait_marker(&mut input.generics, &trait_);
|
||||
// If this trait allows explicit bounds, and any explicit bounds were given,
|
||||
// then use those explicit bounds. Else, apply the default bounds (bound
|
||||
// each generic type on this trait).
|
||||
if let Some(name) = Trait::explicit_bounds_attribute_name() {
|
||||
// See if any explicit bounds were given in attributes.
|
||||
let explicit_bounds = find_and_parse_helper_attributes(
|
||||
&input.attrs,
|
||||
name,
|
||||
"bound",
|
||||
<syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
|
||||
"Type: Trait",
|
||||
"invalid where predicate",
|
||||
)?;
|
||||
|
||||
if !explicit_bounds.is_empty() {
|
||||
// Explicit bounds were given.
|
||||
// Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
|
||||
// bounds for each field's type).
|
||||
let explicit_bounds = explicit_bounds
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<syn::WherePredicate>>();
|
||||
|
||||
let predicates = &mut input.generics.make_where_clause().predicates;
|
||||
|
||||
predicates.extend(explicit_bounds);
|
||||
|
||||
let fields = match &input.data {
|
||||
syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
|
||||
syn::Data::Union(_) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
trait_,
|
||||
&"perfect derive is not supported for unions",
|
||||
));
|
||||
}
|
||||
syn::Data::Enum(_) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
trait_,
|
||||
&"perfect derive is not supported for enums",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
for field in fields {
|
||||
let ty = field.ty;
|
||||
predicates.push(syn::parse_quote!(
|
||||
#ty: #trait_
|
||||
));
|
||||
}
|
||||
} else {
|
||||
// No explicit bounds were given.
|
||||
// Enforce trait bound on all type generics.
|
||||
add_trait_marker(&mut input.generics, &trait_);
|
||||
}
|
||||
} else {
|
||||
// This trait does not allow explicit bounds.
|
||||
// Enforce trait bound on all type generics.
|
||||
add_trait_marker(&mut input.generics, &trait_);
|
||||
}
|
||||
|
||||
let name = &input.ident;
|
||||
|
||||
@ -339,11 +511,8 @@ fn derive_marker_trait_inner<Trait: Derivable>(
|
||||
quote!()
|
||||
};
|
||||
|
||||
let where_clause = if Trait::requires_where_clause() {
|
||||
where_clause
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let where_clause =
|
||||
if Trait::requires_where_clause() { where_clause } else { None };
|
||||
|
||||
Ok(quote! {
|
||||
#asserts
|
||||
@ -364,9 +533,12 @@ fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
|
||||
let type_params = generics
|
||||
.type_params()
|
||||
.map(|param| ¶m.ident)
|
||||
.map(|param| syn::parse_quote!(
|
||||
#param: #trait_name
|
||||
)).collect::<Vec<syn::WherePredicate>>();
|
||||
.map(|param| {
|
||||
syn::parse_quote!(
|
||||
#param: #trait_name
|
||||
)
|
||||
})
|
||||
.collect::<Vec<syn::WherePredicate>>();
|
||||
|
||||
generics.make_where_clause().predicates.extend(type_params);
|
||||
}
|
||||
|
@ -35,6 +35,9 @@ pub trait Derivable {
|
||||
fn requires_where_clause() -> bool {
|
||||
true
|
||||
}
|
||||
fn explicit_bounds_attribute_name() -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Pod;
|
||||
@ -126,6 +129,10 @@ impl Derivable for Zeroable {
|
||||
Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"),
|
||||
}
|
||||
}
|
||||
|
||||
fn explicit_bounds_attribute_name() -> Option<&'static str> {
|
||||
Some("zeroable")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoUninit;
|
||||
@ -532,7 +539,7 @@ fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
|
||||
let size_rest =
|
||||
quote_spanned!(span => #( + ::core::mem::size_of::<#field_types>() )*);
|
||||
|
||||
quote_spanned!(span => #size_first#size_rest)
|
||||
quote_spanned!(span => #size_first #size_rest)
|
||||
} else {
|
||||
quote_spanned!(span => 0)
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user