Allow repr(transparent) to be used generically in derive(Pod) (#139)

* Enabled transparent generics

* Move trait checks to implementation block

* Replace add_trait_marker impl
This commit is contained in:
John Nunley 2022-11-03 06:53:01 -07:00 committed by GitHub
parent 7b67524a43
commit 518baf9c0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 149 additions and 116 deletions

View File

@ -9,7 +9,8 @@ use quote::quote;
use syn::{parse_macro_input, DeriveInput, Result};
use crate::traits::{
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable,
AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, NoUninit, Pod,
TransparentWrapper, Zeroable,
};
/// Derive the `Pod` trait for a struct
@ -56,8 +57,9 @@ pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
pub fn derive_anybitpattern(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<AnyBitPattern>(parse_macro_input!(input as DeriveInput));
let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
input as DeriveInput
));
proc_macro::TokenStream::from(expanded)
}
@ -99,8 +101,8 @@ pub fn derive_zeroable(
/// for the `NoUninit` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait bounds,
/// i.e. the type must be `Sized + Copy + 'static`):
/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
/// bounds, i.e. the type must be `Sized + Copy + 'static`):
///
/// If applied to a struct:
/// - All fields in the struct must implement `NoUninit`
@ -129,9 +131,9 @@ pub fn derive_no_uninit(
/// definition and `is_valid_bit_pattern` method for the type automatically.
///
/// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern` subtrait bounds,
/// i.e. are guaranteed by the requirements of the `NoUninit` trait which `CheckedBitPattern`
/// is a subtrait of):
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
/// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
/// trait which `CheckedBitPattern` is a subtrait of):
///
/// If applied to a struct:
/// - All fields must implement `CheckedBitPattern`
@ -142,8 +144,9 @@ pub fn derive_no_uninit(
pub fn derive_maybe_pod(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(input as DeriveInput));
let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
input as DeriveInput
));
proc_macro::TokenStream::from(expanded)
}
@ -228,17 +231,19 @@ fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
}
fn derive_marker_trait_inner<Trait: Derivable>(
input: DeriveInput,
mut input: DeriveInput,
) -> Result<TokenStream> {
// Enforce Pod on all generic fields.
let trait_ = Trait::ident(&input)?;
add_trait_marker(&mut input.generics, &trait_);
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();
let trait_ = Trait::ident();
Trait::check_attributes(&input.data, &input.attrs)?;
let asserts = Trait::asserts(&input)?;
let trait_params = Trait::generic_params(&input)?;
let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?;
let implies_trait = if let Some(implies_trait) = Trait::implies_trait() {
@ -252,10 +257,23 @@ fn derive_marker_trait_inner<Trait: Derivable>(
#trait_impl_extras
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {
unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
#trait_impl
}
#implies_trait
})
}
/// Add a trait marker to the generics if it is not already present
fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
// Get each generic type parameter.
let type_params = generics
.type_params()
.map(|param| &param.ident)
.map(|param| syn::parse_quote!(
#param: #trait_name
)).collect::<Vec<syn::WherePredicate>>();
generics.make_where_clause().predicates.extend(type_params);
}

View File

@ -1,42 +1,35 @@
#![allow(unused_imports)]
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{*,
parse::{Parse, Parser, ParseStream},
use syn::{
parse::{Parse, ParseStream, Parser},
punctuated::Punctuated,
spanned::Spanned,
Result,
Result, *,
};
macro_rules! bail {
($msg:expr $(,)?) => (
($msg:expr $(,)?) => {
return Err(Error::new(Span::call_site(), &$msg[..]))
);
};
( $msg:expr => $span_to_blame:expr $(,)? ) => (
( $msg:expr => $span_to_blame:expr $(,)? ) => {
return Err(Error::new_spanned(&$span_to_blame, $msg))
);
};
}
pub trait Derivable {
fn ident() -> TokenStream;
fn ident(input: &DeriveInput) -> Result<syn::Path>;
fn implies_trait() -> Option<TokenStream> {
None
}
fn generic_params(_input: &DeriveInput) -> Result<TokenStream> {
Ok(quote!())
}
fn asserts(_input: &DeriveInput) -> Result<TokenStream> {
Ok(quote!())
}
fn check_attributes(
_ty: &Data, _attributes: &[Attribute],
) -> Result<()> {
fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
Ok(())
}
fn trait_impl(
_input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!()))
}
}
@ -44,14 +37,15 @@ pub trait Derivable {
pub struct Pod;
impl Derivable for Pod {
fn ident() -> TokenStream {
quote!(::bytemuck::Pod)
fn ident(_: &DeriveInput) -> Result<syn::Path> {
Ok(syn::parse_quote!(::bytemuck::Pod))
}
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
let repr = get_repr(&input.attrs)?;
let completly_packed = repr.packed == Some(1);
let completly_packed =
repr.packed == Some(1) || repr.repr == Repr::Transparent;
if !completly_packed && !input.generics.params.is_empty() {
bail!("\
@ -69,7 +63,7 @@ impl Derivable for Pod {
None
};
let assert_fields_are_pod =
generate_fields_are_trait(input, Self::ident())?;
generate_fields_are_trait(input, Self::ident(input)?)?;
Ok(quote!(
#assert_no_padding
@ -81,9 +75,7 @@ impl Derivable for Pod {
}
}
fn check_attributes(
_ty: &Data, attributes: &[Attribute],
) -> Result<()> {
fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match repr.repr {
Repr::C => Ok(()),
@ -98,8 +90,8 @@ impl Derivable for Pod {
pub struct AnyBitPattern;
impl Derivable for AnyBitPattern {
fn ident() -> TokenStream {
quote!(::bytemuck::AnyBitPattern)
fn ident(_: &DeriveInput) -> Result<syn::Path> {
Ok(syn::parse_quote!(::bytemuck::AnyBitPattern))
}
fn implies_trait() -> Option<TokenStream> {
@ -109,8 +101,10 @@ impl Derivable for AnyBitPattern {
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
Data::Enum(_) => bail!("Deriving AnyBitPattern is not supported for enums"),
Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?),
Data::Enum(_) => {
bail!("Deriving AnyBitPattern is not supported for enums")
}
}
}
}
@ -118,14 +112,14 @@ impl Derivable for AnyBitPattern {
pub struct Zeroable;
impl Derivable for Zeroable {
fn ident() -> TokenStream {
quote!(::bytemuck::Zeroable)
fn ident(_: &DeriveInput) -> Result<syn::Path> {
Ok(syn::parse_quote!(::bytemuck::Zeroable))
}
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?),
Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"),
}
}
@ -134,13 +128,11 @@ impl Derivable for Zeroable {
pub struct NoUninit;
impl Derivable for NoUninit {
fn ident() -> TokenStream {
quote!(::bytemuck::NoUninit)
fn ident(_: &DeriveInput) -> Result<syn::Path> {
Ok(syn::parse_quote!(::bytemuck::NoUninit))
}
fn check_attributes(
ty: &Data, attributes: &[Attribute],
) -> Result<()> {
fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match ty {
Data::Struct(_) => match repr.repr {
@ -165,7 +157,7 @@ impl Derivable for NoUninit {
Data::Struct(DataStruct { .. }) => {
let assert_no_padding = generate_assert_no_padding(&input)?;
let assert_fields_are_no_padding =
generate_fields_are_trait(&input, Self::ident())?;
generate_fields_are_trait(&input, Self::ident(input)?)?;
Ok(quote!(
#assert_no_padding
@ -179,13 +171,11 @@ impl Derivable for NoUninit {
Ok(quote!())
}
}
Data::Union(_) => bail!("NoUninit cannot be derived for unions"), // shouldn't be possible since we already error in attribute check for this case
Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
fn trait_impl(
_input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!()))
}
}
@ -193,13 +183,11 @@ impl Derivable for NoUninit {
pub struct CheckedBitPattern;
impl Derivable for CheckedBitPattern {
fn ident() -> TokenStream {
quote!(::bytemuck::CheckedBitPattern)
fn ident(_: &DeriveInput) -> Result<syn::Path> {
Ok(syn::parse_quote!(::bytemuck::CheckedBitPattern))
}
fn check_attributes(
ty: &Data, attributes: &[Attribute],
) -> Result<()> {
fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match ty {
Data::Struct(_) => match repr.repr {
@ -223,24 +211,23 @@ impl Derivable for CheckedBitPattern {
match &input.data {
Data::Struct(DataStruct { .. }) => {
let assert_fields_are_maybe_pod =
generate_fields_are_trait(&input, Self::ident())?;
generate_fields_are_trait(&input, Self::ident(input)?)?;
Ok(assert_fields_are_maybe_pod)
}
Data::Enum(_) => Ok(quote!()), // nothing needed, already guaranteed OK by NoUninit
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
* OK by NoUninit */
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
fn trait_impl(
input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
match &input.data {
Data::Struct(DataStruct { fields, .. }) => {
generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs)
},
}
Data::Enum(_) => generate_checked_bit_pattern_enum(input),
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
}
@ -266,20 +253,20 @@ impl TransparentWrapper {
}
impl Derivable for TransparentWrapper {
fn ident() -> TokenStream {
quote!(::bytemuck::TransparentWrapper)
}
fn generic_params(input: &DeriveInput) -> Result<TokenStream> {
fn ident(input: &DeriveInput) -> Result<syn::Path> {
let fields = get_struct_fields(input)?;
match Self::get_wrapper_type(&input.attrs, &fields) {
| Some(ty) => Ok(quote!(<#ty>)),
| None => bail!("\
let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(ty) => ty,
None => bail!(
"\
when deriving TransparentWrapper for a struct with more than one field \
you need to specify the transparent field using #[transparent(T)]\
"),
}
"
),
};
Ok(syn::parse_quote!(::bytemuck::TransparentWrapper<#ty>))
}
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
@ -301,15 +288,15 @@ impl Derivable for TransparentWrapper {
}
}
fn check_attributes(
_ty: &Data, attributes: &[Attribute],
) -> Result<()> {
fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match repr.repr {
Repr::Transparent => Ok(()),
_ => {
bail!("TransparentWrapper requires the struct to be #[repr(transparent)]")
bail!(
"TransparentWrapper requires the struct to be #[repr(transparent)]"
)
}
}
}
@ -318,13 +305,11 @@ impl Derivable for TransparentWrapper {
pub struct Contiguous;
impl Derivable for Contiguous {
fn ident() -> TokenStream {
quote!(::bytemuck::Contiguous)
fn ident(_: &DeriveInput) -> Result<syn::Path> {
Ok(syn::parse_quote!(::bytemuck::Contiguous))
}
fn trait_impl(
input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
let repr = get_repr(&input.attrs)?;
let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() {
@ -422,7 +407,8 @@ fn generate_checked_bit_pattern_struct(
let field_name = &field_names[..];
let field_ty = &field_tys[..];
let derive_dbg = quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
let derive_dbg =
quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
Ok((
quote! {
@ -456,7 +442,11 @@ fn generate_checked_bit_pattern_enum(
(i64::max_value(), i64::min_value(), 0),
|(min, max, count), res| {
let discriminant = res?;
Ok::<_, Error>((i64::min(min, discriminant), i64::max(max, discriminant), count + 1))
Ok::<_, Error>((
i64::min(min, discriminant),
i64::max(max, discriminant),
count + 1,
))
},
)?;
@ -503,9 +493,7 @@ fn generate_checked_bit_pattern_enum(
/// Check that a struct has no padding by asserting that the size of the struct
/// is equal to the sum of the size of it's fields
fn generate_assert_no_padding(
input: &DeriveInput,
) -> Result<TokenStream> {
fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
let struct_type = &input.ident;
let span = input.ident.span();
let fields = get_fields(input)?;
@ -529,7 +517,7 @@ fn generate_assert_no_padding(
/// Check that all fields implement a given trait
fn generate_fields_are_trait(
input: &DeriveInput, trait_: TokenStream,
input: &DeriveInput, trait_: syn::Path,
) -> Result<TokenStream> {
let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl();
@ -574,28 +562,30 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
attributes
.iter()
.filter_map(|attr| if attr.path.is_ident("repr") {
Some(attr.parse_args::<Representation>())
} else {
None
.filter_map(|attr| {
if attr.path.is_ident("repr") {
Some(attr.parse_args::<Representation>())
} else {
None
}
})
.try_fold(Representation::default(), |a, b| {
let b = b?;
Ok(Representation {
repr: match (a.repr, b.repr) {
| (a, Repr::Rust) => a,
| (Repr::Rust, b) => b,
| _ => bail!("conflicting representation hints"),
(a, Repr::Rust) => a,
(Repr::Rust, b) => b,
_ => bail!("conflicting representation hints"),
},
packed: match (a.packed, b.packed) {
| (a, None) => a,
| (None, b) => b,
| _ => bail!("conflicting representation hints"),
(a, None) => a,
(None, b) => b,
_ => bail!("conflicting representation hints"),
},
align: match (a.align, b.align) {
| (a, None) => a,
| (None, b) => b,
| _ => bail!("conflicting representation hints"),
(a, None) => a,
(None, b) => b,
_ => bail!("conflicting representation hints"),
},
})
})
@ -719,7 +709,8 @@ macro_rules! mk_repr {(
));
}
}
)} use mk_repr;
)}
use mk_repr;
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
inner: I,
@ -767,9 +758,7 @@ fn parse_int_expr(expr: &Expr) -> Result<i64> {
Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
parse_int_expr(expr).map(|int| -int)
}
Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
int.base10_parse()
}
Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
_ => bail!("Not an integer expression"),
}
}

View File

@ -1,7 +1,8 @@
#![allow(dead_code)]
use bytemuck::{
AnyBitPattern, Contiguous, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable,
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
TransparentWrapper, Zeroable,
};
use std::marker::PhantomData;
@ -138,9 +139,26 @@ struct CheckedBitPatternStruct {
#[repr(C)]
struct AnyBitPatternTest {
a: u16,
b: u16
b: u16,
}
/// ```compile_fail
/// use bytemuck::{Pod, Zeroable};
///
/// #[derive(Pod, Zeroable)]
/// #[repr(transparent)]
/// struct TransparentSingle<T>(T);
///
/// struct NotPod(u32);
///
/// let _: u32 = bytemuck::cast(TransparentSingle(NotPod(0u32)));
/// ```
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
)]
#[repr(transparent)]
struct NewtypeWrapperTest<T>(T);
#[test]
fn fails_cast_contiguous() {
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
@ -149,7 +167,8 @@ fn fails_cast_contiguous() {
#[test]
fn passes_cast_contiguous() {
let res = bytemuck::checked::from_bytes::<CheckedBitPatternEnumWithValues>(&[2u8]);
let res =
bytemuck::checked::from_bytes::<CheckedBitPatternEnumWithValues>(&[2u8]);
assert_eq!(*res, CheckedBitPatternEnumWithValues::C);
}
@ -162,7 +181,9 @@ fn fails_cast_noncontiguous() {
#[test]
fn passes_cast_noncontiguous() {
let res =
bytemuck::checked::from_bytes::<CheckedBitPatternEnumNonContiguous>(&[56u8]);
bytemuck::checked::from_bytes::<CheckedBitPatternEnumNonContiguous>(&[
56u8,
]);
assert_eq!(*res, CheckedBitPatternEnumNonContiguous::E);
}
@ -177,7 +198,10 @@ fn fails_cast_struct() {
fn passes_cast_struct() {
let pod = [0u8, 8u8];
let res = bytemuck::checked::from_bytes::<CheckedBitPatternStruct>(&pod);
assert_eq!(*res, CheckedBitPatternStruct { a: 0, b: CheckedBitPatternEnumNonContiguous::B });
assert_eq!(
*res,
CheckedBitPatternStruct { a: 0, b: CheckedBitPatternEnumNonContiguous::B }
);
}
#[test]

View File

@ -23,7 +23,9 @@ possibility code branch.
#[cfg(not(target_arch = "spirv"))]
#[cold]
#[inline(never)]
pub(crate) fn something_went_wrong<D: core::fmt::Display>(_src: &str, _err: D) -> ! {
pub(crate) fn something_went_wrong<D: core::fmt::Display>(
_src: &str, _err: D,
) -> ! {
// Note(Lokathor): Keeping the panic here makes the panic _formatting_ go
// here too, which helps assembly readability and also helps keep down
// the inline pressure.