support deriving Pod for packed generic types. (#123)

* improve `#[repr]` parsing

* allow deriving `Pod` for packed generic types

* Revert "Update Cargo.toml"

This reverts commit 6632bcef2c.
This commit is contained in:
Tom Dohrmann 2022-08-07 22:32:00 +02:00 committed by GitHub
parent 331762b014
commit 2c97676bfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 270 additions and 40 deletions

View File

@ -28,9 +28,7 @@ unsound_ptr_pod_impl = []
nightly_portable_simd = [] nightly_portable_simd = []
[dependencies] [dependencies]
# use the upper line for testing against bytemuck_derive changes, if any bytemuck_derive = { version = "1.1", path = "derive", optional = true }
#bytemuck_derive = { path = "./derive", optional = true }
bytemuck_derive = { version = "1.1", optional = true }
[package.metadata.docs.rs] [package.metadata.docs.rs]
# Note(Lokathor): Don't use all-features or it would use `unsound_ptr_pod_impl` too. # Note(Lokathor): Don't use all-features or it would use `unsound_ptr_pod_impl` too.

View File

@ -1,9 +1,9 @@
use proc_macro2::{Ident, TokenStream, TokenTree}; use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens}; use quote::{quote, quote_spanned, ToTokens};
use syn::{ use syn::{
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct, spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp, DataUnion, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Meta,
Variant, DataUnion, NestedMeta, Type, UnOp, Variant,
}; };
pub trait Derivable { pub trait Derivable {
@ -37,13 +37,21 @@ impl Derivable for Pod {
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> { fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
if !input.generics.params.is_empty() { let repr = get_repr(&input.attrs);
return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs");
let completly_packed = repr.packed == Some(1);
if !completly_packed && !input.generics.params.is_empty() {
return Err("Pod requires cannot be derived for non-packed types containing generic parameters because the padding requirements can't be verified for generic non-packed structs");
} }
match &input.data { match &input.data {
Data::Struct(_) => { Data::Struct(_) => {
let assert_no_padding = generate_assert_no_padding(input)?; let assert_no_padding = if !completly_packed {
Some(generate_assert_no_padding(input)?)
} else {
None
};
let assert_fields_are_pod = let assert_fields_are_pod =
generate_fields_are_trait(input, Self::ident())?; generate_fields_are_trait(input, Self::ident())?;
@ -51,9 +59,9 @@ impl Derivable for Pod {
#assert_no_padding #assert_no_padding
#assert_fields_are_pod #assert_fields_are_pod
)) ))
}, }
Data::Enum(_) => Err("Deriving Pod is not supported for enums"), Data::Enum(_) => Err("Deriving Pod is not supported for enums"),
Data::Union(_) => Err("Deriving Pod is not supported for unions") Data::Union(_) => Err("Deriving Pod is not supported for unions"),
} }
} }
@ -61,9 +69,9 @@ impl Derivable for Pod {
_ty: &Data, attributes: &[Attribute], _ty: &Data, attributes: &[Attribute],
) -> Result<(), &'static str> { ) -> Result<(), &'static str> {
let repr = get_repr(attributes); let repr = get_repr(attributes);
match repr.as_ref().map(|repr| repr.as_str()) { match repr.repr {
Some("C") => Ok(()), Repr::C => Ok(()),
Some("transparent") => Ok(()), Repr::Transparent => Ok(()),
_ => { _ => {
Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]") Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
} }
@ -119,11 +127,11 @@ impl Derivable for NoUninit {
) -> Result<(), &'static str> { ) -> Result<(), &'static str> {
let repr = get_repr(attributes); let repr = get_repr(attributes);
match ty { match ty {
Data::Struct(_) => match repr.as_deref() { Data::Struct(_) => match repr.repr {
Some("C" | "transparent") => Ok(()), Repr::C | Repr::Transparent => Ok(()),
_ => Err("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"), _ => Err("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
}, },
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) { Data::Enum(_) => if repr.repr.is_integer() {
Ok(()) Ok(())
} else { } else {
Err("NoUninit requires the enum to be an explicit #[repr(Int)]") Err("NoUninit requires the enum to be an explicit #[repr(Int)]")
@ -178,11 +186,11 @@ impl Derivable for CheckedBitPattern {
) -> Result<(), &'static str> { ) -> Result<(), &'static str> {
let repr = get_repr(attributes); let repr = get_repr(attributes);
match ty { match ty {
Data::Struct(_) => match repr.as_deref() { Data::Struct(_) => match repr.repr {
Some("C" | "transparent") => Ok(()), Repr::C | Repr::Transparent => Ok(()),
_ => Err("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"), _ => Err("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
}, },
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) { Data::Enum(_) => if repr.repr.is_integer() {
Ok(()) Ok(())
} else { } else {
Err("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]") Err("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
@ -212,9 +220,9 @@ impl Derivable for CheckedBitPattern {
input: &DeriveInput, input: &DeriveInput,
) -> Result<(TokenStream, TokenStream), &'static str> { ) -> Result<(TokenStream, TokenStream), &'static str> {
match &input.data { match &input.data {
Data::Struct(DataStruct { fields, .. }) => { Data::Struct(DataStruct { fields, .. }) => Ok(
Ok(generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs)) generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs),
} ),
Data::Enum(_) => generate_checked_bit_pattern_enum(input), Data::Enum(_) => generate_checked_bit_pattern_enum(input),
Data::Union(_) => Err("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case Data::Union(_) => Err("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case
} }
@ -277,8 +285,8 @@ impl Derivable for TransparentWrapper {
) -> Result<(), &'static str> { ) -> Result<(), &'static str> {
let repr = get_repr(attributes); let repr = get_repr(attributes);
match repr.as_ref().map(|repr| repr.as_str()) { match repr.repr {
Some("transparent") => Ok(()), Repr::Transparent => Ok(()),
_ => { _ => {
Err("TransparentWrapper requires the struct to be #[repr(transparent)]") Err("TransparentWrapper requires the struct to be #[repr(transparent)]")
} }
@ -296,12 +304,13 @@ impl Derivable for Contiguous {
fn trait_impl( fn trait_impl(
input: &DeriveInput, input: &DeriveInput,
) -> Result<(TokenStream, TokenStream), &'static str> { ) -> Result<(TokenStream, TokenStream), &'static str> {
let repr = get_repr(&input.attrs) let repr = get_repr(&input.attrs);
.ok_or("Contiguous requires the enum to be #[repr(Int)]")?;
if !repr.starts_with('u') && !repr.starts_with('i') { let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() {
integer_ty
} else {
return Err("Contiguous requires the enum to be #[repr(Int)]"); return Err("Contiguous requires the enum to be #[repr(Int)]");
} };
let variants = get_enum_variants(input)?; let variants = get_enum_variants(input)?;
let mut variants_with_discriminator = let mut variants_with_discriminator =
@ -325,16 +334,15 @@ impl Derivable for Contiguous {
); );
} }
let repr_ident = Ident::new(&repr, input.span());
let min_lit = LitInt::new(&format!("{}", min), input.span()); let min_lit = LitInt::new(&format!("{}", min), input.span());
let max_lit = LitInt::new(&format!("{}", max), input.span()); let max_lit = LitInt::new(&format!("{}", max), input.span());
Ok(( Ok((
quote!(), quote!(),
quote! { quote! {
type Int = #repr_ident; type Int = #integer_ty;
const MIN_VALUE: #repr_ident = #min_lit; const MIN_VALUE: #integer_ty = #min_lit;
const MAX_VALUE: #repr_ident = #max_lit; const MAX_VALUE: #integer_ty = #max_lit;
}, },
)) ))
} }
@ -352,7 +360,7 @@ fn get_fields(input: &DeriveInput) -> Result<Fields, &'static str> {
match &input.data { match &input.data {
Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()), Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())), Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
Data::Enum(_) => Err("deriving this trait is not supported for enums") Data::Enum(_) => Err("deriving this trait is not supported for enums"),
} }
} }
@ -377,7 +385,7 @@ fn generate_checked_bit_pattern_struct(
) -> (TokenStream, TokenStream) { ) -> (TokenStream, TokenStream) {
let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span()); let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
let repr = get_simple_attr(attrs, "repr").unwrap(); // should be checked in attr check already let repr = get_repr(attrs);
let field_names = fields let field_names = fields
.iter() .iter()
@ -400,7 +408,7 @@ fn generate_checked_bit_pattern_struct(
( (
quote! { quote! {
#[repr(#repr)] #repr
#[derive(Clone, Copy, ::bytemuck::AnyBitPattern)] #[derive(Clone, Copy, ::bytemuck::AnyBitPattern)]
#derive_dbg #derive_dbg
pub struct #bits_ty { pub struct #bits_ty {
@ -459,11 +467,12 @@ fn generate_checked_bit_pattern_enum(
quote!(matches!(*bits, #first #(| #rest )*)) quote!(matches!(*bits, #first #(| #rest )*))
}; };
let repr = get_simple_attr(&input.attrs, "repr").unwrap(); // should be checked in attr check already let repr = get_repr(&input.attrs);
let integer_ty = repr.repr.as_integer_type().unwrap(); // should be checked in attr check already
Ok(( Ok((
quote!(), quote!(),
quote! { quote! {
type Bits = #repr; type Bits = #integer_ty;
#[inline] #[inline]
#[allow(clippy::double_comparisons)] #[allow(clippy::double_comparisons)]
@ -544,8 +553,199 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
None None
} }
fn get_repr(attributes: &[Attribute]) -> Option<String> { #[derive(Clone, Copy)]
get_simple_attr(attributes, "repr").map(|ident| ident.to_string()) struct Representation {
packed: Option<u32>,
repr: Repr,
}
impl ToTokens for Representation {
fn to_tokens(&self, tokens: &mut TokenStream) {
let repr = match self.repr {
Repr::Rust => None,
Repr::C => Some("C"),
Repr::Transparent => Some("transparent"),
Repr::U8 => Some("u8"),
Repr::I8 => Some("i8"),
Repr::U16 => Some("u16"),
Repr::I16 => Some("i16"),
Repr::U32 => Some("u32"),
Repr::I32 => Some("i32"),
Repr::U64 => Some("u64"),
Repr::I64 => Some("i64"),
Repr::I128 => Some("i128"),
Repr::U128 => Some("u128"),
};
if let Some(repr) = repr {
let ident = Ident::new(repr, Span::call_site());
tokens.extend(quote! {
#[repr(#ident)]
});
}
if let Some(packed) = self.packed {
let lit = LitInt::new(&packed.to_string(), Span::call_site());
tokens.extend(quote! {
#[repr(packed(#lit))]
});
}
}
}
#[derive(Clone, Copy)]
enum Repr {
Rust,
C,
Transparent,
U8,
I8,
U16,
I16,
U32,
I32,
U64,
I64,
I128,
U128,
}
impl Repr {
fn is_integer(&self) -> bool {
match *self {
Repr::Rust | Repr::C | Repr::Transparent => false,
Repr::U8
| Repr::I8
| Repr::U16
| Repr::I16
| Repr::U32
| Repr::I32
| Repr::U64
| Repr::I64
| Repr::I128
| Repr::U128 => true,
}
}
fn as_integer_type(&self) -> Option<TokenStream> {
match self {
Repr::Rust | Repr::C | Repr::Transparent => None,
Repr::U8 => Some(quote! { ::core::primitive::u8 }),
Repr::I8 => Some(quote! { ::core::primitive::i8 }),
Repr::U16 => Some(quote! { ::core::primitive::u16 }),
Repr::I16 => Some(quote! { ::core::primitive::i16 }),
Repr::U32 => Some(quote! { ::core::primitive::u32 }),
Repr::I32 => Some(quote! { ::core::primitive::i32 }),
Repr::U64 => Some(quote! { ::core::primitive::u64 }),
Repr::I64 => Some(quote! { ::core::primitive::i64 }),
Repr::I128 => Some(quote! { ::core::primitive::u128 }),
Repr::U128 => Some(quote! { ::core::primitive::i128 }),
}
}
}
fn get_repr(attributes: &[Attribute]) -> Representation {
let mut repr = Representation { packed: None, repr: Repr::Rust };
for attr in attributes {
let meta = if let Ok(meta) = attr.parse_meta() { meta } else { continue };
if !meta.path().is_ident("repr") {
continue;
}
let list = if let Meta::List(list) = meta {
list
} else {
// The other `Meta` variants are illegal for `repr`.
continue;
};
for item in list.nested {
let meta = if let NestedMeta::Meta(meta) = item {
meta
} else {
// Other nested items are illegal for `repr`.
continue;
};
match meta.path() {
path if path.is_ident("C") => {
repr.repr = Repr::C;
}
path if path.is_ident("transparent") => {
repr.repr = Repr::Transparent;
}
path if path.is_ident("u8") => {
repr.repr = Repr::U8;
}
path if path.is_ident("i8") => {
repr.repr = Repr::I8;
}
path if path.is_ident("u16") => {
repr.repr = Repr::U16;
}
path if path.is_ident("i16") => {
repr.repr = Repr::I16;
}
path if path.is_ident("u32") => {
repr.repr = Repr::U32;
}
path if path.is_ident("i32") => {
repr.repr = Repr::I32;
}
path if path.is_ident("u64") => {
repr.repr = Repr::U64;
}
path if path.is_ident("i64") => {
repr.repr = Repr::I64;
}
path if path.is_ident("u128") => {
repr.repr = Repr::U128;
}
path if path.is_ident("i128") => {
repr.repr = Repr::I128;
}
path if path.is_ident("packed") => {
let packed_alignment = match meta {
Meta::Path(_) => 1,
Meta::List(list) => {
if list.nested.len() != 1 {
// `repr(packed(n))` must have exactly one nested item.
continue;
}
let nested = &list.nested[0];
let int_lit = if let NestedMeta::Lit(Lit::Int(int_lit)) = nested {
int_lit
} else {
// The nested item must be an integer literal.
continue;
};
let value = if let Ok(value) = int_lit.base10_parse::<u32>() {
value
} else {
// The literal must be positive and less than 2^29.
continue;
};
value
}
Meta::NameValue(_) => {
// `repr(packed)` doesn't support name value syntax.
continue;
}
};
let new_packed_alignment = match repr.packed {
Some(prev) => u32::min(prev, packed_alignment),
None => packed_alignment,
};
repr.packed = Some(new_packed_alignment);
}
_ => {}
}
}
}
repr
} }
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> { struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {

View File

@ -12,6 +12,38 @@ struct Test {
b: u16, b: u16,
} }
#[derive(Pod, Zeroable)]
#[repr(C, packed)]
struct GenericPackedStruct<T: Pod> {
a: u32,
b: T,
c: u32,
}
impl<T: Pod> Clone for GenericPackedStruct<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T: Pod> Copy for GenericPackedStruct<T> {}
#[derive(Pod, Zeroable)]
#[repr(C, packed(1))]
struct GenericPackedStructExplicitPackedAlignment<T: Pod> {
a: u32,
b: T,
c: u32,
}
impl<T: Pod> Clone for GenericPackedStructExplicitPackedAlignment<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T: Pod> Copy for GenericPackedStructExplicitPackedAlignment<T> {}
#[derive(Zeroable)] #[derive(Zeroable)]
struct ZeroGeneric<T: bytemuck::Zeroable> { struct ZeroGeneric<T: bytemuck::Zeroable> {
a: T, a: T,