diff --git a/Cargo.toml b/Cargo.toml index 0732036..b60e41b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,9 +28,7 @@ unsound_ptr_pod_impl = [] nightly_portable_simd = [] [dependencies] -# use the upper line for testing against bytemuck_derive changes, if any -#bytemuck_derive = { path = "./derive", optional = true } -bytemuck_derive = { version = "1.1", optional = true } +bytemuck_derive = { version = "1.1", path = "derive", optional = true } [package.metadata.docs.rs] # Note(Lokathor): Don't use all-features or it would use `unsound_ptr_pod_impl` too. diff --git a/derive/src/traits.rs b/derive/src/traits.rs index a7791c7..0230ac9 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -1,9 +1,9 @@ -use proc_macro2::{Ident, TokenStream, TokenTree}; +use proc_macro2::{Ident, Span, TokenStream, TokenTree}; use quote::{quote, quote_spanned, ToTokens}; use syn::{ spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct, - DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp, - Variant, DataUnion, + DataUnion, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Meta, + NestedMeta, Type, UnOp, Variant, }; pub trait Derivable { @@ -37,13 +37,21 @@ impl Derivable for Pod { } fn asserts(input: &DeriveInput) -> Result { - if !input.generics.params.is_empty() { - return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs"); + let repr = get_repr(&input.attrs); + + let completly_packed = repr.packed == Some(1); + + if !completly_packed && !input.generics.params.is_empty() { + return Err("Pod requires cannot be derived for non-packed types containing generic parameters because the padding requirements can't be verified for generic non-packed structs"); } match &input.data { Data::Struct(_) => { - let assert_no_padding = generate_assert_no_padding(input)?; + let assert_no_padding = if !completly_packed { + Some(generate_assert_no_padding(input)?) + } else { + None + }; let assert_fields_are_pod = generate_fields_are_trait(input, Self::ident())?; @@ -51,9 +59,9 @@ impl Derivable for Pod { #assert_no_padding #assert_fields_are_pod )) - }, + } Data::Enum(_) => Err("Deriving Pod is not supported for enums"), - Data::Union(_) => Err("Deriving Pod is not supported for unions") + Data::Union(_) => Err("Deriving Pod is not supported for unions"), } } @@ -61,9 +69,9 @@ impl Derivable for Pod { _ty: &Data, attributes: &[Attribute], ) -> Result<(), &'static str> { let repr = get_repr(attributes); - match repr.as_ref().map(|repr| repr.as_str()) { - Some("C") => Ok(()), - Some("transparent") => Ok(()), + match repr.repr { + Repr::C => Ok(()), + Repr::Transparent => Ok(()), _ => { Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]") } @@ -119,11 +127,11 @@ impl Derivable for NoUninit { ) -> Result<(), &'static str> { let repr = get_repr(attributes); match ty { - Data::Struct(_) => match repr.as_deref() { - Some("C" | "transparent") => Ok(()), + Data::Struct(_) => match repr.repr { + Repr::C | Repr::Transparent => Ok(()), _ => Err("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"), }, - Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) { + Data::Enum(_) => if repr.repr.is_integer() { Ok(()) } else { Err("NoUninit requires the enum to be an explicit #[repr(Int)]") @@ -178,11 +186,11 @@ impl Derivable for CheckedBitPattern { ) -> Result<(), &'static str> { let repr = get_repr(attributes); match ty { - Data::Struct(_) => match repr.as_deref() { - Some("C" | "transparent") => Ok(()), + Data::Struct(_) => match repr.repr { + Repr::C | Repr::Transparent => Ok(()), _ => Err("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"), }, - Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) { + Data::Enum(_) => if repr.repr.is_integer() { Ok(()) } else { Err("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]") @@ -212,9 +220,9 @@ impl Derivable for CheckedBitPattern { input: &DeriveInput, ) -> Result<(TokenStream, TokenStream), &'static str> { match &input.data { - Data::Struct(DataStruct { fields, .. }) => { - Ok(generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs)) - } + Data::Struct(DataStruct { fields, .. }) => Ok( + generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs), + ), Data::Enum(_) => generate_checked_bit_pattern_enum(input), Data::Union(_) => Err("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case } @@ -277,8 +285,8 @@ impl Derivable for TransparentWrapper { ) -> Result<(), &'static str> { let repr = get_repr(attributes); - match repr.as_ref().map(|repr| repr.as_str()) { - Some("transparent") => Ok(()), + match repr.repr { + Repr::Transparent => Ok(()), _ => { Err("TransparentWrapper requires the struct to be #[repr(transparent)]") } @@ -296,12 +304,13 @@ impl Derivable for Contiguous { fn trait_impl( input: &DeriveInput, ) -> Result<(TokenStream, TokenStream), &'static str> { - let repr = get_repr(&input.attrs) - .ok_or("Contiguous requires the enum to be #[repr(Int)]")?; + let repr = get_repr(&input.attrs); - if !repr.starts_with('u') && !repr.starts_with('i') { + let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() { + integer_ty + } else { return Err("Contiguous requires the enum to be #[repr(Int)]"); - } + }; let variants = get_enum_variants(input)?; let mut variants_with_discriminator = @@ -325,16 +334,15 @@ impl Derivable for Contiguous { ); } - let repr_ident = Ident::new(&repr, input.span()); let min_lit = LitInt::new(&format!("{}", min), input.span()); let max_lit = LitInt::new(&format!("{}", max), input.span()); Ok(( quote!(), quote! { - type Int = #repr_ident; - const MIN_VALUE: #repr_ident = #min_lit; - const MAX_VALUE: #repr_ident = #max_lit; + type Int = #integer_ty; + const MIN_VALUE: #integer_ty = #min_lit; + const MAX_VALUE: #integer_ty = #max_lit; }, )) } @@ -352,7 +360,7 @@ fn get_fields(input: &DeriveInput) -> Result { match &input.data { Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()), Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())), - Data::Enum(_) => Err("deriving this trait is not supported for enums") + Data::Enum(_) => Err("deriving this trait is not supported for enums"), } } @@ -377,7 +385,7 @@ fn generate_checked_bit_pattern_struct( ) -> (TokenStream, TokenStream) { let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span()); - let repr = get_simple_attr(attrs, "repr").unwrap(); // should be checked in attr check already + let repr = get_repr(attrs); let field_names = fields .iter() @@ -400,7 +408,7 @@ fn generate_checked_bit_pattern_struct( ( quote! { - #[repr(#repr)] + #repr #[derive(Clone, Copy, ::bytemuck::AnyBitPattern)] #derive_dbg pub struct #bits_ty { @@ -459,11 +467,12 @@ fn generate_checked_bit_pattern_enum( quote!(matches!(*bits, #first #(| #rest )*)) }; - let repr = get_simple_attr(&input.attrs, "repr").unwrap(); // should be checked in attr check already + let repr = get_repr(&input.attrs); + let integer_ty = repr.repr.as_integer_type().unwrap(); // should be checked in attr check already Ok(( quote!(), quote! { - type Bits = #repr; + type Bits = #integer_ty; #[inline] #[allow(clippy::double_comparisons)] @@ -544,8 +553,199 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option { None } -fn get_repr(attributes: &[Attribute]) -> Option { - get_simple_attr(attributes, "repr").map(|ident| ident.to_string()) +#[derive(Clone, Copy)] +struct Representation { + packed: Option, + repr: Repr, +} + +impl ToTokens for Representation { + fn to_tokens(&self, tokens: &mut TokenStream) { + let repr = match self.repr { + Repr::Rust => None, + Repr::C => Some("C"), + Repr::Transparent => Some("transparent"), + Repr::U8 => Some("u8"), + Repr::I8 => Some("i8"), + Repr::U16 => Some("u16"), + Repr::I16 => Some("i16"), + Repr::U32 => Some("u32"), + Repr::I32 => Some("i32"), + Repr::U64 => Some("u64"), + Repr::I64 => Some("i64"), + Repr::I128 => Some("i128"), + Repr::U128 => Some("u128"), + }; + if let Some(repr) = repr { + let ident = Ident::new(repr, Span::call_site()); + tokens.extend(quote! { + #[repr(#ident)] + }); + } + + if let Some(packed) = self.packed { + let lit = LitInt::new(&packed.to_string(), Span::call_site()); + tokens.extend(quote! { + #[repr(packed(#lit))] + }); + } + } +} + +#[derive(Clone, Copy)] +enum Repr { + Rust, + C, + Transparent, + U8, + I8, + U16, + I16, + U32, + I32, + U64, + I64, + I128, + U128, +} + +impl Repr { + fn is_integer(&self) -> bool { + match *self { + Repr::Rust | Repr::C | Repr::Transparent => false, + Repr::U8 + | Repr::I8 + | Repr::U16 + | Repr::I16 + | Repr::U32 + | Repr::I32 + | Repr::U64 + | Repr::I64 + | Repr::I128 + | Repr::U128 => true, + } + } + + fn as_integer_type(&self) -> Option { + match self { + Repr::Rust | Repr::C | Repr::Transparent => None, + Repr::U8 => Some(quote! { ::core::primitive::u8 }), + Repr::I8 => Some(quote! { ::core::primitive::i8 }), + Repr::U16 => Some(quote! { ::core::primitive::u16 }), + Repr::I16 => Some(quote! { ::core::primitive::i16 }), + Repr::U32 => Some(quote! { ::core::primitive::u32 }), + Repr::I32 => Some(quote! { ::core::primitive::i32 }), + Repr::U64 => Some(quote! { ::core::primitive::u64 }), + Repr::I64 => Some(quote! { ::core::primitive::i64 }), + Repr::I128 => Some(quote! { ::core::primitive::u128 }), + Repr::U128 => Some(quote! { ::core::primitive::i128 }), + } + } +} + +fn get_repr(attributes: &[Attribute]) -> Representation { + let mut repr = Representation { packed: None, repr: Repr::Rust }; + + for attr in attributes { + let meta = if let Ok(meta) = attr.parse_meta() { meta } else { continue }; + if !meta.path().is_ident("repr") { + continue; + } + let list = if let Meta::List(list) = meta { + list + } else { + // The other `Meta` variants are illegal for `repr`. + continue; + }; + + for item in list.nested { + let meta = if let NestedMeta::Meta(meta) = item { + meta + } else { + // Other nested items are illegal for `repr`. + continue; + }; + + match meta.path() { + path if path.is_ident("C") => { + repr.repr = Repr::C; + } + path if path.is_ident("transparent") => { + repr.repr = Repr::Transparent; + } + path if path.is_ident("u8") => { + repr.repr = Repr::U8; + } + path if path.is_ident("i8") => { + repr.repr = Repr::I8; + } + path if path.is_ident("u16") => { + repr.repr = Repr::U16; + } + path if path.is_ident("i16") => { + repr.repr = Repr::I16; + } + path if path.is_ident("u32") => { + repr.repr = Repr::U32; + } + path if path.is_ident("i32") => { + repr.repr = Repr::I32; + } + path if path.is_ident("u64") => { + repr.repr = Repr::U64; + } + path if path.is_ident("i64") => { + repr.repr = Repr::I64; + } + path if path.is_ident("u128") => { + repr.repr = Repr::U128; + } + path if path.is_ident("i128") => { + repr.repr = Repr::I128; + } + path if path.is_ident("packed") => { + let packed_alignment = match meta { + Meta::Path(_) => 1, + Meta::List(list) => { + if list.nested.len() != 1 { + // `repr(packed(n))` must have exactly one nested item. + continue; + } + + let nested = &list.nested[0]; + let int_lit = if let NestedMeta::Lit(Lit::Int(int_lit)) = nested { + int_lit + } else { + // The nested item must be an integer literal. + continue; + }; + + let value = if let Ok(value) = int_lit.base10_parse::() { + value + } else { + // The literal must be positive and less than 2^29. + continue; + }; + value + } + Meta::NameValue(_) => { + // `repr(packed)` doesn't support name value syntax. + continue; + } + }; + + let new_packed_alignment = match repr.packed { + Some(prev) => u32::min(prev, packed_alignment), + None => packed_alignment, + }; + repr.packed = Some(new_packed_alignment); + } + _ => {} + } + } + } + + repr } struct VariantDiscriminantIterator<'a, I: Iterator + 'a> { diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index 456266d..3e164bb 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -12,6 +12,38 @@ struct Test { b: u16, } +#[derive(Pod, Zeroable)] +#[repr(C, packed)] +struct GenericPackedStruct { + a: u32, + b: T, + c: u32, +} + +impl Clone for GenericPackedStruct { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for GenericPackedStruct {} + +#[derive(Pod, Zeroable)] +#[repr(C, packed(1))] +struct GenericPackedStructExplicitPackedAlignment { + a: u32, + b: T, + c: u32, +} + +impl Clone for GenericPackedStructExplicitPackedAlignment { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for GenericPackedStructExplicitPackedAlignment {} + #[derive(Zeroable)] struct ZeroGeneric { a: T,