mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-25 00:02:22 +00:00
allow deriving CheckedBitPattern
for enums with fields (#171)
* simplify `ToTokens` impl for `Representation` Instead of collecting the representation and modifier into `Option`s and determining whether a comma is needed manually, we can use the `Puncutuated` struct which handles commas automatically. This will also make emitting the `align` modifier in the future easier. * emit alignment modifier This is required for correctly implementing `CheckedBitPattern` because we need the layout of the type and its `Bits` type to have the same layout. * add unit test for `#[repr]` parsing * allow multiple alignment modifiers According to RFC #1358, if multiple alignment modifiers are specified, the resulting alignment is the maximum of all alignment modifiers. * actually return the error we just created * factor out the integer Repr's into their own type This is a preparation step for adding support for `#[repr(C, int)]`. * allow parsing `#[repr(C, int)]` This can be used on enums with fields. * derive `CheckedBitPattern` for enums with fields The implementation mostly mirrors the desugaring described at https://doc.rust-lang.org/reference/type-layout.html * add comments and rename some idents * update error message * update docs for `CheckedBitPattern` derive * add new nested test case, change generated type naming scheme * fix wrong comment * small nit --------- Co-authored-by: Gray Olson <gray@grayolson.com>
This commit is contained in:
parent
ff0b14dae9
commit
d10fbfc6ff
@ -218,7 +218,7 @@ pub fn derive_zeroable(
|
||||
/// - The struct must contain no generic parameters
|
||||
///
|
||||
/// If applied to an enum:
|
||||
/// - The enum must be explicit `#[repr(Int)]`
|
||||
/// - The enum must be explicit `#[repr(Int)]`, `#[repr(C)]`, or both
|
||||
/// - All variants must be fieldless
|
||||
/// - The enum must contain no generic parameters
|
||||
#[proc_macro_derive(NoUninit)]
|
||||
@ -237,16 +237,17 @@ pub fn derive_no_uninit(
|
||||
/// for the `CheckedBitPattern` trait and derives the required `Bits` type
|
||||
/// definition and `is_valid_bit_pattern` method for the type automatically.
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed
|
||||
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
|
||||
/// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
|
||||
/// trait which `CheckedBitPattern` is a subtrait of):
|
||||
/// The following constraints need to be satisfied for the macro to succeed:
|
||||
///
|
||||
/// If applied to a struct:
|
||||
/// - All fields must implement `CheckedBitPattern`
|
||||
/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
|
||||
/// - The struct must contain no generic parameters
|
||||
///
|
||||
/// If applied to an enum:
|
||||
/// - All requirements already checked by `NoUninit`, just impls the trait
|
||||
/// - The enum must be explicit `#[repr(Int)]`
|
||||
/// - All fields in variants must implement `CheckedBitPattern`
|
||||
/// - The enum must contain no generic parameters
|
||||
#[proc_macro_derive(CheckedBitPattern)]
|
||||
pub fn derive_maybe_pod(
|
||||
input: proc_macro::TokenStream,
|
||||
|
@ -1,4 +1,6 @@
|
||||
#![allow(unused_imports)]
|
||||
use std::{cmp, convert::TryFrom};
|
||||
|
||||
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
|
||||
use quote::{quote, quote_spanned, ToTokens};
|
||||
use syn::{
|
||||
@ -204,11 +206,19 @@ impl Derivable for CheckedBitPattern {
|
||||
Repr::C | Repr::Transparent => Ok(()),
|
||||
_ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
|
||||
},
|
||||
Data::Enum(_) => if repr.repr.is_integer() {
|
||||
Data::Enum(DataEnum { variants,.. }) => {
|
||||
if !enum_has_fields(variants.iter()){
|
||||
if repr.repr.is_integer() {
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
|
||||
},
|
||||
}
|
||||
} else if matches!(repr.repr, Repr::Rust) {
|
||||
bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
|
||||
}
|
||||
}
|
||||
@ -235,7 +245,9 @@ impl Derivable for CheckedBitPattern {
|
||||
Data::Struct(DataStruct { fields, .. }) => {
|
||||
generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs)
|
||||
}
|
||||
Data::Enum(_) => generate_checked_bit_pattern_enum(input),
|
||||
Data::Enum(DataEnum { variants, .. }) => {
|
||||
generate_checked_bit_pattern_enum(input, variants)
|
||||
}
|
||||
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
|
||||
}
|
||||
}
|
||||
@ -347,13 +359,20 @@ impl Derivable for Contiguous {
|
||||
fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
|
||||
let repr = get_repr(&input.attrs)?;
|
||||
|
||||
let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() {
|
||||
let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
|
||||
integer_ty
|
||||
} else {
|
||||
bail!("Contiguous requires the enum to be #[repr(Int)]");
|
||||
};
|
||||
|
||||
let variants = get_enum_variants(input)?;
|
||||
if enum_has_fields(variants.clone()) {
|
||||
return Err(Error::new_spanned(
|
||||
&input,
|
||||
"Only fieldless enums are supported",
|
||||
));
|
||||
}
|
||||
|
||||
let mut variants_with_discriminator =
|
||||
VariantDiscriminantIterator::new(variants);
|
||||
|
||||
@ -426,7 +445,7 @@ fn get_fields(input: &DeriveInput) -> Result<Fields> {
|
||||
|
||||
fn get_enum_variants<'a>(
|
||||
input: &'a DeriveInput,
|
||||
) -> Result<impl Iterator<Item = &'a Variant> + 'a> {
|
||||
) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
|
||||
if let Data::Enum(DataEnum { variants, .. }) = &input.data {
|
||||
Ok(variants.iter())
|
||||
} else {
|
||||
@ -486,11 +505,21 @@ fn generate_checked_bit_pattern_struct(
|
||||
}
|
||||
|
||||
fn generate_checked_bit_pattern_enum(
|
||||
input: &DeriveInput,
|
||||
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
|
||||
) -> Result<(TokenStream, TokenStream)> {
|
||||
if enum_has_fields(variants.iter()) {
|
||||
generate_checked_bit_pattern_enum_with_fields(input, variants)
|
||||
} else {
|
||||
generate_checked_bit_pattern_enum_without_fields(input, variants)
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_checked_bit_pattern_enum_without_fields(
|
||||
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
|
||||
) -> Result<(TokenStream, TokenStream)> {
|
||||
let span = input.span();
|
||||
let mut variants_with_discriminant =
|
||||
VariantDiscriminantIterator::new(get_enum_variants(input)?);
|
||||
VariantDiscriminantIterator::new(variants.iter());
|
||||
|
||||
let (min, max, count) = variants_with_discriminant.try_fold(
|
||||
(i64::max_value(), i64::min_value(), 0),
|
||||
@ -514,8 +543,7 @@ fn generate_checked_bit_pattern_enum(
|
||||
quote!(*bits >= #min_lit && *bits <= #max_lit)
|
||||
} else {
|
||||
// not contiguous range, check for each
|
||||
let variant_lits =
|
||||
VariantDiscriminantIterator::new(get_enum_variants(input)?)
|
||||
let variant_lits = VariantDiscriminantIterator::new(variants.iter())
|
||||
.map(|res| {
|
||||
let variant = res?;
|
||||
Ok(LitInt::new(&format!("{}", variant), span))
|
||||
@ -530,11 +558,11 @@ fn generate_checked_bit_pattern_enum(
|
||||
};
|
||||
|
||||
let repr = get_repr(&input.attrs)?;
|
||||
let integer_ty = repr.repr.as_integer_type().unwrap(); // should be checked in attr check already
|
||||
let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already
|
||||
Ok((
|
||||
quote!(),
|
||||
quote! {
|
||||
type Bits = #integer_ty;
|
||||
type Bits = #integer;
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::double_comparisons)]
|
||||
@ -545,6 +573,244 @@ fn generate_checked_bit_pattern_enum(
|
||||
))
|
||||
}
|
||||
|
||||
fn generate_checked_bit_pattern_enum_with_fields(
|
||||
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
|
||||
) -> Result<(TokenStream, TokenStream)> {
|
||||
let representation = get_repr(&input.attrs)?;
|
||||
let vis = &input.vis;
|
||||
|
||||
let derive_dbg =
|
||||
quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
|
||||
|
||||
match representation.repr {
|
||||
Repr::Rust => unreachable!(),
|
||||
repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
|
||||
let integer = match repr {
|
||||
Repr::C => quote!(::core::ffi::c_int),
|
||||
Repr::CWithDiscriminant(integer) => quote!(#integer),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let input_ident = &input.ident;
|
||||
|
||||
let bits_repr = Representation { repr: Repr::C, ..representation };
|
||||
|
||||
// the enum manually re-configured as the actual tagged union it represents,
|
||||
// thus circumventing the requirements rust imposes on the tag even when using
|
||||
// #[repr(C)] enum layout
|
||||
// see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
|
||||
let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span());
|
||||
|
||||
// the variants union part of the tagged union. These get put into a union which gets the
|
||||
// AnyBitPattern derive applied to it, thus checking that the fields of the union obey the requriements of AnyBitPattern.
|
||||
// The types that actually go in the union are one more level of indirection deep: we generate new structs for each variant
|
||||
// (`variant_struct_definitions`) which themselves have the `CheckedBitPattern` derive applied, thus generating `{variant_struct_ident}Bits`
|
||||
// structs, which are the ones that go into this union.
|
||||
let variants_union_ident =
|
||||
Ident::new(&format!("{}Variants", input.ident), input.span());
|
||||
|
||||
let variant_struct_idents = variants
|
||||
.iter()
|
||||
.map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()));
|
||||
|
||||
let variant_struct_definitions =
|
||||
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
|
||||
let fields = v.fields.iter().map(|v| &v.ty);
|
||||
|
||||
quote! {
|
||||
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)]
|
||||
#[repr(C)]
|
||||
#vis struct #variant_struct_ident(#(#fields),*);
|
||||
}
|
||||
});
|
||||
|
||||
let union_fields =
|
||||
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
|
||||
let variant_struct_bits_ident =
|
||||
Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
|
||||
let field_ident = &v.ident;
|
||||
quote! {
|
||||
#field_ident: #variant_struct_bits_ident
|
||||
}
|
||||
});
|
||||
|
||||
let variant_checks = variant_struct_idents
|
||||
.clone()
|
||||
.zip(VariantDiscriminantIterator::new(variants.iter()))
|
||||
.zip(variants.iter())
|
||||
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
|
||||
let discriminant = discriminant?;
|
||||
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
|
||||
let ident = &v.ident;
|
||||
Ok(quote! {
|
||||
#discriminant => {
|
||||
let payload = unsafe { &bits.payload.#ident };
|
||||
<#variant_struct_ident as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(payload)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
Ok((
|
||||
quote! {
|
||||
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)]
|
||||
#derive_dbg
|
||||
#bits_repr
|
||||
#vis struct #bits_ty_ident {
|
||||
tag: #integer,
|
||||
payload: #variants_union_ident,
|
||||
}
|
||||
|
||||
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)]
|
||||
#[repr(C)]
|
||||
#[allow(non_snake_case)]
|
||||
#vis union #variants_union_ident {
|
||||
#(#union_fields,)*
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
impl ::core::fmt::Debug for #variants_union_ident {
|
||||
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
|
||||
let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
|
||||
::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
|
||||
}
|
||||
}
|
||||
|
||||
#(#variant_struct_definitions)*
|
||||
},
|
||||
quote! {
|
||||
type Bits = #bits_ty_ident;
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::double_comparisons)]
|
||||
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
|
||||
match bits.tag {
|
||||
#(#variant_checks)*
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
},
|
||||
))
|
||||
}
|
||||
Repr::Transparent => {
|
||||
if variants.len() != 1 {
|
||||
bail!("enums with more than one variant cannot be transparent")
|
||||
}
|
||||
|
||||
let variant = &variants[0];
|
||||
|
||||
let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
|
||||
let fields = variant.fields.iter().map(|v| &v.ty);
|
||||
|
||||
Ok((
|
||||
quote! {
|
||||
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)]
|
||||
#[repr(C)]
|
||||
#vis struct #bits_ty(#(#fields),*);
|
||||
},
|
||||
quote! {
|
||||
type Bits = <#bits_ty as ::bytemuck::CheckedBitPattern>::Bits;
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::double_comparisons)]
|
||||
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
|
||||
<#bits_ty as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(bits)
|
||||
}
|
||||
},
|
||||
))
|
||||
}
|
||||
Repr::Integer(integer) => {
|
||||
let bits_repr = Representation { repr: Repr::C, ..representation };
|
||||
let input_ident = &input.ident;
|
||||
|
||||
// the enum manually re-configured as the union it represents. such a union is the union of variants
|
||||
// as a repr(c) struct with the discriminator type inserted at the beginning.
|
||||
// in our case we union the `Bits` representation of each variant rather than the variant itself, which we generate
|
||||
// via a nested `CheckedBitPattern` derive on the `variant_struct_definitions` generated below.
|
||||
//
|
||||
// see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
|
||||
let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span());
|
||||
|
||||
let variant_struct_idents = variants
|
||||
.iter()
|
||||
.map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()));
|
||||
|
||||
let variant_struct_definitions =
|
||||
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
|
||||
let fields = v.fields.iter().map(|v| &v.ty);
|
||||
|
||||
// adding the discriminant repr integer as first field, as described above
|
||||
quote! {
|
||||
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)]
|
||||
#[repr(C)]
|
||||
#vis struct #variant_struct_ident(#integer, #(#fields),*);
|
||||
}
|
||||
});
|
||||
|
||||
let union_fields =
|
||||
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
|
||||
let variant_struct_bits_ident =
|
||||
Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
|
||||
let field_ident = &v.ident;
|
||||
quote! {
|
||||
#field_ident: #variant_struct_bits_ident
|
||||
}
|
||||
});
|
||||
|
||||
let variant_checks = variant_struct_idents
|
||||
.clone()
|
||||
.zip(VariantDiscriminantIterator::new(variants.iter()))
|
||||
.zip(variants.iter())
|
||||
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
|
||||
let discriminant = discriminant?;
|
||||
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
|
||||
let ident = &v.ident;
|
||||
Ok(quote! {
|
||||
#discriminant => {
|
||||
let payload = unsafe { &bits.#ident };
|
||||
<#variant_struct_ident as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(payload)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
Ok((
|
||||
quote! {
|
||||
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)]
|
||||
#bits_repr
|
||||
#[allow(non_snake_case)]
|
||||
#vis union #bits_ty_ident {
|
||||
__tag: #integer,
|
||||
#(#union_fields,)*
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
impl ::core::fmt::Debug for #bits_ty_ident {
|
||||
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
|
||||
let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
|
||||
::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
|
||||
::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
|
||||
}
|
||||
}
|
||||
|
||||
#(#variant_struct_definitions)*
|
||||
},
|
||||
quote! {
|
||||
type Bits = #bits_ty_ident;
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::double_comparisons)]
|
||||
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
|
||||
match unsafe { bits.__tag } {
|
||||
#(#variant_checks)*
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that a struct has no padding by asserting that the size of the struct
|
||||
/// is equal to the sum of the size of it's fields
|
||||
fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
|
||||
@ -637,9 +903,9 @@ fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
|
||||
_ => bail!("conflicting representation hints"),
|
||||
},
|
||||
align: match (a.align, b.align) {
|
||||
(Some(a), Some(b)) => Some(cmp::max(a, b)),
|
||||
(a, None) => a,
|
||||
(None, b) => b,
|
||||
_ => bail!("conflicting representation hints"),
|
||||
},
|
||||
})
|
||||
})
|
||||
@ -665,46 +931,73 @@ macro_rules! mk_repr {(
|
||||
$Xn:ident => $xn:ident
|
||||
),* $(,)?
|
||||
) => (
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
enum Repr {
|
||||
Rust,
|
||||
C,
|
||||
Transparent,
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum IntegerRepr {
|
||||
$($Xn),*
|
||||
}
|
||||
|
||||
impl Repr {
|
||||
fn is_integer(self) -> bool {
|
||||
match self {
|
||||
Repr::Rust | Repr::C | Repr::Transparent => false,
|
||||
_ => true,
|
||||
impl<'a> TryFrom<&'a str> for IntegerRepr {
|
||||
type Error = &'a str;
|
||||
|
||||
fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
|
||||
match value {
|
||||
$(
|
||||
stringify!($xn) => Ok(Self::$Xn),
|
||||
)*
|
||||
_ => Err(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn as_integer_type(self) -> Option<TokenStream> {
|
||||
impl ToTokens for IntegerRepr {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
match self {
|
||||
Repr::Rust | Repr::C | Repr::Transparent => None,
|
||||
$(
|
||||
Repr::$Xn => Some(quote! { ::core::primitive::$xn }),
|
||||
Self::$Xn => tokens.extend(quote!($xn)),
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
)}
|
||||
use mk_repr;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct Representation {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Repr {
|
||||
Rust,
|
||||
C,
|
||||
Transparent,
|
||||
Integer(IntegerRepr),
|
||||
CWithDiscriminant(IntegerRepr),
|
||||
}
|
||||
|
||||
impl Repr {
|
||||
fn is_integer(&self) -> bool {
|
||||
matches!(self, Self::Integer(..))
|
||||
}
|
||||
|
||||
fn as_integer(&self) -> Option<IntegerRepr> {
|
||||
if let Self::Integer(v) = self {
|
||||
Some(*v)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct Representation {
|
||||
packed: Option<u32>,
|
||||
align: Option<u32>,
|
||||
repr: Repr,
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Representation {
|
||||
impl Default for Representation {
|
||||
fn default() -> Self {
|
||||
Self { packed: None, align: None, repr: Repr::Rust }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for Representation {
|
||||
impl Parse for Representation {
|
||||
fn parse(input: ParseStream<'_>) -> Result<Representation> {
|
||||
let mut ret = Representation::default();
|
||||
while !input.is_empty() {
|
||||
@ -716,60 +1009,90 @@ macro_rules! mk_repr {(
|
||||
"transparent" => Repr::Transparent,
|
||||
"packed" => {
|
||||
ret.packed = Some(if input.peek(token::Paren) {
|
||||
let contents; parenthesized!(contents in input);
|
||||
let contents;
|
||||
parenthesized!(contents in input);
|
||||
LitInt::base10_parse::<u32>(&contents.parse()?)?
|
||||
} else {
|
||||
1
|
||||
});
|
||||
let _: Option<Token![,]> = input.parse()?;
|
||||
continue;
|
||||
},
|
||||
}
|
||||
"align" => {
|
||||
let contents; parenthesized!(contents in input);
|
||||
ret.align = Some(LitInt::base10_parse::<u32>(&contents.parse()?)?);
|
||||
let contents;
|
||||
parenthesized!(contents in input);
|
||||
let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
|
||||
ret.align = Some(
|
||||
ret
|
||||
.align
|
||||
.map_or(new_align, |old_align| cmp::max(old_align, new_align)),
|
||||
);
|
||||
let _: Option<Token![,]> = input.parse()?;
|
||||
continue;
|
||||
},
|
||||
$(
|
||||
stringify!($xn) => Repr::$Xn,
|
||||
)*
|
||||
_ => return Err(input.error("unrecognized representation hint"))
|
||||
};
|
||||
if ::core::mem::replace(&mut ret.repr, new_repr) != Repr::Rust {
|
||||
input.error("duplicate representation hint");
|
||||
}
|
||||
ident => {
|
||||
let primitive = IntegerRepr::try_from(ident)
|
||||
.map_err(|_| input.error("unrecognized representation hint"))?;
|
||||
Repr::Integer(primitive)
|
||||
}
|
||||
};
|
||||
ret.repr = match (ret.repr, new_repr) {
|
||||
(Repr::Rust, new_repr) => {
|
||||
// This is the first explicit repr.
|
||||
new_repr
|
||||
}
|
||||
(Repr::C, Repr::Integer(integer))
|
||||
| (Repr::Integer(integer), Repr::C) => {
|
||||
// Both the C repr and an integer repr have been specified
|
||||
// -> merge into a C wit discriminant.
|
||||
Repr::CWithDiscriminant(integer)
|
||||
}
|
||||
(_, _) => {
|
||||
return Err(input.error("duplicate representation hint"));
|
||||
}
|
||||
};
|
||||
let _: Option<Token![,]> = input.parse()?;
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for Representation {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let mut meta = Punctuated::<_, Token![,]>::new();
|
||||
|
||||
match self.repr {
|
||||
Repr::Rust => {}
|
||||
Repr::C => meta.push(quote!(C)),
|
||||
Repr::Transparent => meta.push(quote!(transparent)),
|
||||
Repr::Integer(primitive) => meta.push(quote!(#primitive)),
|
||||
Repr::CWithDiscriminant(primitive) => {
|
||||
meta.push(quote!(C));
|
||||
meta.push(quote!(#primitive));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(packed) = self.packed.as_ref() {
|
||||
let lit = LitInt::new(&packed.to_string(), Span::call_site());
|
||||
meta.push(quote!(packed(#lit)));
|
||||
}
|
||||
|
||||
if let Some(align) = self.align.as_ref() {
|
||||
let lit = LitInt::new(&align.to_string(), Span::call_site());
|
||||
meta.push(quote!(align(#lit)));
|
||||
}
|
||||
|
||||
impl ToTokens for Representation {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let repr = match self.repr {
|
||||
Repr::Rust => None,
|
||||
Repr::C => Some(quote!(C)),
|
||||
Repr::Transparent => Some(quote!(transparent)),
|
||||
$(
|
||||
Repr::$Xn => Some(quote!($xn)),
|
||||
)*
|
||||
};
|
||||
let packed = self.packed.map(|p| {
|
||||
let lit = LitInt::new(&p.to_string(), Span::call_site());
|
||||
quote!(packed(#lit))
|
||||
});
|
||||
let comma = if packed.is_some() && repr.is_some() {
|
||||
Some(quote!(,))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tokens.extend(quote!(
|
||||
#[repr( #repr #comma #packed )]
|
||||
#[repr(#meta)]
|
||||
));
|
||||
}
|
||||
}
|
||||
)}
|
||||
use mk_repr;
|
||||
}
|
||||
|
||||
fn enum_has_fields<'a>(
|
||||
mut variants: impl Iterator<Item = &'a Variant>,
|
||||
) -> bool {
|
||||
variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
|
||||
}
|
||||
|
||||
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
|
||||
inner: I,
|
||||
@ -791,12 +1114,6 @@ impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let variant = self.inner.next()?;
|
||||
if !variant.fields.is_empty() {
|
||||
return Some(Err(Error::new_spanned(
|
||||
&variant.fields,
|
||||
"Only fieldless enums are supported",
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some((_, discriminant)) = &variant.discriminant {
|
||||
let discriminant_value = match parse_int_expr(discriminant) {
|
||||
@ -822,3 +1139,83 @@ fn parse_int_expr(expr: &Expr) -> Result<i64> {
|
||||
_ => bail!("Not an integer expression"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use syn::parse_quote;
|
||||
|
||||
use super::{get_repr, IntegerRepr, Repr, Representation};
|
||||
|
||||
#[test]
|
||||
fn parse_basic_repr() {
|
||||
let attr = parse_quote!(#[repr(C)]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(repr, Representation { repr: Repr::C, ..Default::default() });
|
||||
|
||||
let attr = parse_quote!(#[repr(transparent)]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(
|
||||
repr,
|
||||
Representation { repr: Repr::Transparent, ..Default::default() }
|
||||
);
|
||||
|
||||
let attr = parse_quote!(#[repr(u8)]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(
|
||||
repr,
|
||||
Representation {
|
||||
repr: Repr::Integer(IntegerRepr::U8),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
|
||||
let attr = parse_quote!(#[repr(packed)]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });
|
||||
|
||||
let attr = parse_quote!(#[repr(packed(1))]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });
|
||||
|
||||
let attr = parse_quote!(#[repr(packed(2))]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(repr, Representation { packed: Some(2), ..Default::default() });
|
||||
|
||||
let attr = parse_quote!(#[repr(align(2))]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(repr, Representation { align: Some(2), ..Default::default() });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_advanced_repr() {
|
||||
let attr = parse_quote!(#[repr(align(4), align(2))]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(repr, Representation { align: Some(4), ..Default::default() });
|
||||
|
||||
let attr1 = parse_quote!(#[repr(align(1))]);
|
||||
let attr2 = parse_quote!(#[repr(align(4))]);
|
||||
let attr3 = parse_quote!(#[repr(align(2))]);
|
||||
let repr = get_repr(&[attr1, attr2, attr3]).unwrap();
|
||||
assert_eq!(repr, Representation { align: Some(4), ..Default::default() });
|
||||
|
||||
let attr = parse_quote!(#[repr(C, u8)]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(
|
||||
repr,
|
||||
Representation {
|
||||
repr: Repr::CWithDiscriminant(IntegerRepr::U8),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
|
||||
let attr = parse_quote!(#[repr(u8, C)]);
|
||||
let repr = get_repr(&[attr]).unwrap();
|
||||
assert_eq!(
|
||||
repr,
|
||||
Representation {
|
||||
repr: Repr::CWithDiscriminant(IntegerRepr::U8),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
use bytemuck::{
|
||||
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
|
||||
TransparentWrapper, Zeroable,
|
||||
TransparentWrapper, Zeroable, checked::CheckedCastError,
|
||||
};
|
||||
use std::marker::{PhantomData, PhantomPinned};
|
||||
|
||||
@ -160,6 +160,66 @@ struct AnyBitPatternTest<A: AnyBitPattern, B: AnyBitPattern> {
|
||||
b: B,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, CheckedBitPattern)]
|
||||
#[repr(C, align(8))]
|
||||
struct CheckedBitPatternAlignedStruct {
|
||||
a: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
|
||||
#[repr(C)]
|
||||
enum CheckedBitPatternCDefaultDiscriminantEnumWithFields {
|
||||
A(u64),
|
||||
B { c: u64 },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
|
||||
#[repr(C, u8)]
|
||||
enum CheckedBitPatternCEnumWithFields {
|
||||
A(u32),
|
||||
B { c: u32 },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
enum CheckedBitPatternIntEnumWithFields {
|
||||
A(u8),
|
||||
B { c: u32 },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
enum CheckedBitPatternTransparentEnumWithFields {
|
||||
A { b: u32 },
|
||||
}
|
||||
|
||||
// size 24, align 8.
|
||||
// first byte always the u8 discriminant, then 7 bytes of padding until the payload union since the align of the payload
|
||||
// is the greatest of the align of all the variants, which is 8 (from CheckedBitPatternCDefaultDiscriminantEnumWithFields)
|
||||
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
|
||||
#[repr(C, u8)]
|
||||
enum CheckedBitPatternEnumNested {
|
||||
A(CheckedBitPatternCEnumWithFields),
|
||||
B(CheckedBitPatternCDefaultDiscriminantEnumWithFields),
|
||||
}
|
||||
|
||||
/// ```compile_fail
|
||||
/// use bytemuck::{Pod, Zeroable};
|
||||
///
|
||||
/// #[derive(Pod, Zeroable)]
|
||||
/// #[repr(transparent)]
|
||||
/// struct TransparentSingle<T>(T);
|
||||
///
|
||||
/// struct NotPod(u32);
|
||||
///
|
||||
/// let _: u32 = bytemuck::cast(TransparentSingle(NotPod(0u32)));
|
||||
/// ```
|
||||
#[derive(
|
||||
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
struct NewtypeWrapperTest<T>(T);
|
||||
|
||||
#[test]
|
||||
fn fails_cast_contiguous() {
|
||||
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
|
||||
@ -246,6 +306,140 @@ fn checkedbitpattern_try_pod_read_unaligned() {
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checkedbitpattern_aligned_struct() {
|
||||
let pod = [0u8; 8];
|
||||
bytemuck::checked::pod_read_unaligned::<CheckedBitPatternAlignedStruct>(&pod);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checkedbitpattern_c_default_discriminant_enum_with_fields() {
|
||||
let pod = [
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0x55,
|
||||
0x55, 0x55, 0x55, 0xcc,
|
||||
];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternCDefaultDiscriminantEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(
|
||||
value,
|
||||
CheckedBitPatternCDefaultDiscriminantEnumWithFields::A(0xcc555555555555cc)
|
||||
);
|
||||
|
||||
let pod = [
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0x55,
|
||||
0x55, 0x55, 0x55, 0xcc,
|
||||
];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternCDefaultDiscriminantEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(
|
||||
value,
|
||||
CheckedBitPatternCDefaultDiscriminantEnumWithFields::B {
|
||||
c: 0xcc555555555555cc
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checkedbitpattern_c_enum_with_fields() {
|
||||
let pod = [0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternCEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(value, CheckedBitPatternCEnumWithFields::A(0xcc5555cc));
|
||||
|
||||
let pod = [0x01, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternCEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(value, CheckedBitPatternCEnumWithFields::B { c: 0xcc5555cc });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checkedbitpattern_int_enum_with_fields() {
|
||||
let pod = [0x00, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternIntEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(value, CheckedBitPatternIntEnumWithFields::A(0x55));
|
||||
|
||||
let pod = [0x01, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternIntEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(value, CheckedBitPatternIntEnumWithFields::B { c: 0xcc5555cc });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checkedbitpattern_nested_enum_with_fields() {
|
||||
// total size 24 bytes. first byte always the u8 discriminant.
|
||||
|
||||
#[repr(C, align(8))]
|
||||
struct Align8Bytes([u8; 24]);
|
||||
|
||||
// first we'll check variantA, nested variant A
|
||||
let pod = Align8Bytes([
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding.
|
||||
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding
|
||||
]);
|
||||
let value = bytemuck::checked::from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
assert_eq!(value, &CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A(0xcc5555cc)));
|
||||
|
||||
// next we'll check invalid first discriminant fails
|
||||
let pod = Align8Bytes([
|
||||
0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding
|
||||
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding
|
||||
]);
|
||||
let result = bytemuck::checked::try_from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
assert_eq!(result, Err(CheckedCastError::InvalidBitPattern));
|
||||
|
||||
|
||||
// next we'll check variant B, nested variant B
|
||||
let pod = Align8Bytes([
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte order) = variant B
|
||||
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B
|
||||
]);
|
||||
let value = bytemuck::checked::from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
assert_eq!(
|
||||
value,
|
||||
&CheckedBitPatternEnumNested::B(CheckedBitPatternCDefaultDiscriminantEnumWithFields::B {
|
||||
c: 0xcc555555555555cc
|
||||
})
|
||||
);
|
||||
|
||||
// finally we'll check variant B, nested invalid discriminant
|
||||
let pod = Align8Bytes([
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 discriminant = variant B, bytes 1-7 padding
|
||||
0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is invalid
|
||||
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B
|
||||
]);
|
||||
let result = bytemuck::checked::try_from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
assert_eq!(result, Err(CheckedCastError::InvalidBitPattern));
|
||||
}
|
||||
#[test]
|
||||
fn checkedbitpattern_transparent_enum_with_fields() {
|
||||
let pod = [0xcc, 0x55, 0x55, 0xcc];
|
||||
let value = bytemuck::checked::pod_read_unaligned::<
|
||||
CheckedBitPatternTransparentEnumWithFields,
|
||||
>(&pod);
|
||||
assert_eq!(
|
||||
value,
|
||||
CheckedBitPatternTransparentEnumWithFields::A { b: 0xcc5555cc }
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
|
||||
#[repr(C, align(16))]
|
||||
struct Issue127 {}
|
||||
|
Loading…
Reference in New Issue
Block a user