allow deriving CheckedBitPattern for enums with fields (#171)

* simplify `ToTokens` impl for `Representation`

Instead of collecting the representation and modifier into `Option`s
and determining whether a comma is needed manually, we can use the
`Puncutuated` struct which handles commas automatically.

This will also make emitting the `align` modifier in the future easier.

* emit alignment modifier

This is required for correctly implementing `CheckedBitPattern` because
we need the layout of the type and its `Bits` type to have the same
layout.

* add unit test for `#[repr]` parsing

* allow multiple alignment modifiers

According to RFC #1358, if multiple alignment modifiers are specified,
the resulting alignment is the maximum of all alignment modifiers.

* actually return the error we just created

* factor out the integer Repr's into their own type

This is a preparation step for adding support for `#[repr(C, int)]`.

* allow parsing `#[repr(C, int)]`

This can be used on enums with fields.

* derive `CheckedBitPattern` for enums with fields

The implementation mostly mirrors the desugaring described at
https://doc.rust-lang.org/reference/type-layout.html

* add comments and rename some idents

* update error message

* update docs for `CheckedBitPattern` derive

* add new nested test case, change generated type naming scheme

* fix wrong comment

* small nit

---------

Co-authored-by: Gray Olson <gray@grayolson.com>
This commit is contained in:
Tom Dohrmann 2023-09-06 17:37:07 +02:00 committed by GitHub
parent ff0b14dae9
commit d10fbfc6ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 713 additions and 121 deletions

View File

@ -218,7 +218,7 @@ pub fn derive_zeroable(
/// - The struct must contain no generic parameters
///
/// If applied to an enum:
/// - The enum must be explicit `#[repr(Int)]`
/// - The enum must be explicit `#[repr(Int)]`, `#[repr(C)]`, or both
/// - All variants must be fieldless
/// - The enum must contain no generic parameters
#[proc_macro_derive(NoUninit)]
@ -237,16 +237,17 @@ pub fn derive_no_uninit(
/// for the `CheckedBitPattern` trait and derives the required `Bits` type
/// definition and `is_valid_bit_pattern` method for the type automatically.
///
/// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
/// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
/// trait which `CheckedBitPattern` is a subtrait of):
/// The following constraints need to be satisfied for the macro to succeed:
///
/// If applied to a struct:
/// - All fields must implement `CheckedBitPattern`
/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
/// - The struct must contain no generic parameters
///
/// If applied to an enum:
/// - All requirements already checked by `NoUninit`, just impls the trait
/// - The enum must be explicit `#[repr(Int)]`
/// - All fields in variants must implement `CheckedBitPattern`
/// - The enum must contain no generic parameters
#[proc_macro_derive(CheckedBitPattern)]
pub fn derive_maybe_pod(
input: proc_macro::TokenStream,

View File

@ -1,4 +1,6 @@
#![allow(unused_imports)]
use std::{cmp, convert::TryFrom};
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
@ -204,11 +206,19 @@ impl Derivable for CheckedBitPattern {
Repr::C | Repr::Transparent => Ok(()),
_ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
},
Data::Enum(_) => if repr.repr.is_integer() {
Ok(())
} else {
bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
},
Data::Enum(DataEnum { variants,.. }) => {
if !enum_has_fields(variants.iter()){
if repr.repr.is_integer() {
Ok(())
} else {
bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
}
} else if matches!(repr.repr, Repr::Rust) {
bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
} else {
Ok(())
}
}
Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
}
}
@ -235,7 +245,9 @@ impl Derivable for CheckedBitPattern {
Data::Struct(DataStruct { fields, .. }) => {
generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs)
}
Data::Enum(_) => generate_checked_bit_pattern_enum(input),
Data::Enum(DataEnum { variants, .. }) => {
generate_checked_bit_pattern_enum(input, variants)
}
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
@ -347,13 +359,20 @@ impl Derivable for Contiguous {
fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
let repr = get_repr(&input.attrs)?;
let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() {
let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
integer_ty
} else {
bail!("Contiguous requires the enum to be #[repr(Int)]");
};
let variants = get_enum_variants(input)?;
if enum_has_fields(variants.clone()) {
return Err(Error::new_spanned(
&input,
"Only fieldless enums are supported",
));
}
let mut variants_with_discriminator =
VariantDiscriminantIterator::new(variants);
@ -426,7 +445,7 @@ fn get_fields(input: &DeriveInput) -> Result<Fields> {
fn get_enum_variants<'a>(
input: &'a DeriveInput,
) -> Result<impl Iterator<Item = &'a Variant> + 'a> {
) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
if let Data::Enum(DataEnum { variants, .. }) = &input.data {
Ok(variants.iter())
} else {
@ -486,11 +505,21 @@ fn generate_checked_bit_pattern_struct(
}
fn generate_checked_bit_pattern_enum(
input: &DeriveInput,
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
) -> Result<(TokenStream, TokenStream)> {
if enum_has_fields(variants.iter()) {
generate_checked_bit_pattern_enum_with_fields(input, variants)
} else {
generate_checked_bit_pattern_enum_without_fields(input, variants)
}
}
fn generate_checked_bit_pattern_enum_without_fields(
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
) -> Result<(TokenStream, TokenStream)> {
let span = input.span();
let mut variants_with_discriminant =
VariantDiscriminantIterator::new(get_enum_variants(input)?);
VariantDiscriminantIterator::new(variants.iter());
let (min, max, count) = variants_with_discriminant.try_fold(
(i64::max_value(), i64::min_value(), 0),
@ -514,13 +543,12 @@ fn generate_checked_bit_pattern_enum(
quote!(*bits >= #min_lit && *bits <= #max_lit)
} else {
// not contiguous range, check for each
let variant_lits =
VariantDiscriminantIterator::new(get_enum_variants(input)?)
.map(|res| {
let variant = res?;
Ok(LitInt::new(&format!("{}", variant), span))
})
.collect::<Result<Vec<_>>>()?;
let variant_lits = VariantDiscriminantIterator::new(variants.iter())
.map(|res| {
let variant = res?;
Ok(LitInt::new(&format!("{}", variant), span))
})
.collect::<Result<Vec<_>>>()?;
// count is at least 1
let first = &variant_lits[0];
@ -530,11 +558,11 @@ fn generate_checked_bit_pattern_enum(
};
let repr = get_repr(&input.attrs)?;
let integer_ty = repr.repr.as_integer_type().unwrap(); // should be checked in attr check already
let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already
Ok((
quote!(),
quote! {
type Bits = #integer_ty;
type Bits = #integer;
#[inline]
#[allow(clippy::double_comparisons)]
@ -545,6 +573,244 @@ fn generate_checked_bit_pattern_enum(
))
}
fn generate_checked_bit_pattern_enum_with_fields(
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
) -> Result<(TokenStream, TokenStream)> {
let representation = get_repr(&input.attrs)?;
let vis = &input.vis;
let derive_dbg =
quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
match representation.repr {
Repr::Rust => unreachable!(),
repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
let integer = match repr {
Repr::C => quote!(::core::ffi::c_int),
Repr::CWithDiscriminant(integer) => quote!(#integer),
_ => unreachable!(),
};
let input_ident = &input.ident;
let bits_repr = Representation { repr: Repr::C, ..representation };
// the enum manually re-configured as the actual tagged union it represents,
// thus circumventing the requirements rust imposes on the tag even when using
// #[repr(C)] enum layout
// see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span());
// the variants union part of the tagged union. These get put into a union which gets the
// AnyBitPattern derive applied to it, thus checking that the fields of the union obey the requriements of AnyBitPattern.
// The types that actually go in the union are one more level of indirection deep: we generate new structs for each variant
// (`variant_struct_definitions`) which themselves have the `CheckedBitPattern` derive applied, thus generating `{variant_struct_ident}Bits`
// structs, which are the ones that go into this union.
let variants_union_ident =
Ident::new(&format!("{}Variants", input.ident), input.span());
let variant_struct_idents = variants
.iter()
.map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()));
let variant_struct_definitions =
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
let fields = v.fields.iter().map(|v| &v.ty);
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)]
#[repr(C)]
#vis struct #variant_struct_ident(#(#fields),*);
}
});
let union_fields =
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
let variant_struct_bits_ident =
Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
let field_ident = &v.ident;
quote! {
#field_ident: #variant_struct_bits_ident
}
});
let variant_checks = variant_struct_idents
.clone()
.zip(VariantDiscriminantIterator::new(variants.iter()))
.zip(variants.iter())
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
let discriminant = discriminant?;
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
let ident = &v.ident;
Ok(quote! {
#discriminant => {
let payload = unsafe { &bits.payload.#ident };
<#variant_struct_ident as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(payload)
}
})
})
.collect::<Result<Vec<_>>>()?;
Ok((
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)]
#derive_dbg
#bits_repr
#vis struct #bits_ty_ident {
tag: #integer,
payload: #variants_union_ident,
}
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)]
#[repr(C)]
#[allow(non_snake_case)]
#vis union #variants_union_ident {
#(#union_fields,)*
}
#[cfg(not(target_arch = "spirv"))]
impl ::core::fmt::Debug for #variants_union_ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
}
}
#(#variant_struct_definitions)*
},
quote! {
type Bits = #bits_ty_ident;
#[inline]
#[allow(clippy::double_comparisons)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
match bits.tag {
#(#variant_checks)*
_ => false,
}
}
},
))
}
Repr::Transparent => {
if variants.len() != 1 {
bail!("enums with more than one variant cannot be transparent")
}
let variant = &variants[0];
let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
let fields = variant.fields.iter().map(|v| &v.ty);
Ok((
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)]
#[repr(C)]
#vis struct #bits_ty(#(#fields),*);
},
quote! {
type Bits = <#bits_ty as ::bytemuck::CheckedBitPattern>::Bits;
#[inline]
#[allow(clippy::double_comparisons)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
<#bits_ty as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(bits)
}
},
))
}
Repr::Integer(integer) => {
let bits_repr = Representation { repr: Repr::C, ..representation };
let input_ident = &input.ident;
// the enum manually re-configured as the union it represents. such a union is the union of variants
// as a repr(c) struct with the discriminator type inserted at the beginning.
// in our case we union the `Bits` representation of each variant rather than the variant itself, which we generate
// via a nested `CheckedBitPattern` derive on the `variant_struct_definitions` generated below.
//
// see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span());
let variant_struct_idents = variants
.iter()
.map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()));
let variant_struct_definitions =
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
let fields = v.fields.iter().map(|v| &v.ty);
// adding the discriminant repr integer as first field, as described above
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)]
#[repr(C)]
#vis struct #variant_struct_ident(#integer, #(#fields),*);
}
});
let union_fields =
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
let variant_struct_bits_ident =
Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
let field_ident = &v.ident;
quote! {
#field_ident: #variant_struct_bits_ident
}
});
let variant_checks = variant_struct_idents
.clone()
.zip(VariantDiscriminantIterator::new(variants.iter()))
.zip(variants.iter())
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
let discriminant = discriminant?;
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
let ident = &v.ident;
Ok(quote! {
#discriminant => {
let payload = unsafe { &bits.#ident };
<#variant_struct_ident as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(payload)
}
})
})
.collect::<Result<Vec<_>>>()?;
Ok((
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)]
#bits_repr
#[allow(non_snake_case)]
#vis union #bits_ty_ident {
__tag: #integer,
#(#union_fields,)*
}
#[cfg(not(target_arch = "spirv"))]
impl ::core::fmt::Debug for #bits_ty_ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
}
}
#(#variant_struct_definitions)*
},
quote! {
type Bits = #bits_ty_ident;
#[inline]
#[allow(clippy::double_comparisons)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
match unsafe { bits.__tag } {
#(#variant_checks)*
_ => false,
}
}
},
))
}
}
}
/// Check that a struct has no padding by asserting that the size of the struct
/// is equal to the sum of the size of it's fields
fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
@ -637,9 +903,9 @@ fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
_ => bail!("conflicting representation hints"),
},
align: match (a.align, b.align) {
(Some(a), Some(b)) => Some(cmp::max(a, b)),
(a, None) => a,
(None, b) => b,
_ => bail!("conflicting representation hints"),
},
})
})
@ -665,112 +931,169 @@ macro_rules! mk_repr {(
$Xn:ident => $xn:ident
),* $(,)?
) => (
#[derive(Clone, Copy, PartialEq)]
enum Repr {
Rust,
C,
Transparent,
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntegerRepr {
$($Xn),*
}
impl Repr {
fn is_integer(self) -> bool {
match self {
Repr::Rust | Repr::C | Repr::Transparent => false,
_ => true,
}
}
impl<'a> TryFrom<&'a str> for IntegerRepr {
type Error = &'a str;
fn as_integer_type(self) -> Option<TokenStream> {
match self {
Repr::Rust | Repr::C | Repr::Transparent => None,
fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
match value {
$(
Repr::$Xn => Some(quote! { ::core::primitive::$xn }),
stringify!($xn) => Ok(Self::$Xn),
)*
_ => Err(value),
}
}
}
#[derive(Clone, Copy)]
struct Representation {
packed: Option<u32>,
align: Option<u32>,
repr: Repr,
}
impl Default for Representation {
fn default() -> Self {
Self { packed: None, align: None, repr: Repr::Rust }
}
}
impl Parse for Representation {
fn parse(input: ParseStream<'_>) -> Result<Representation> {
let mut ret = Representation::default();
while !input.is_empty() {
let keyword = input.parse::<Ident>()?;
// preëmptively call `.to_string()` *once* (rather than on `is_ident()`)
let keyword_str = keyword.to_string();
let new_repr = match keyword_str.as_str() {
"C" => Repr::C,
"transparent" => Repr::Transparent,
"packed" => {
ret.packed = Some(if input.peek(token::Paren) {
let contents; parenthesized!(contents in input);
LitInt::base10_parse::<u32>(&contents.parse()?)?
} else {
1
});
let _: Option<Token![,]> = input.parse()?;
continue;
},
"align" => {
let contents; parenthesized!(contents in input);
ret.align = Some(LitInt::base10_parse::<u32>(&contents.parse()?)?);
let _: Option<Token![,]> = input.parse()?;
continue;
},
$(
stringify!($xn) => Repr::$Xn,
)*
_ => return Err(input.error("unrecognized representation hint"))
};
if ::core::mem::replace(&mut ret.repr, new_repr) != Repr::Rust {
input.error("duplicate representation hint");
}
let _: Option<Token![,]> = input.parse()?;
}
Ok(ret)
}
}
impl ToTokens for Representation {
impl ToTokens for IntegerRepr {
fn to_tokens(&self, tokens: &mut TokenStream) {
let repr = match self.repr {
Repr::Rust => None,
Repr::C => Some(quote!(C)),
Repr::Transparent => Some(quote!(transparent)),
match self {
$(
Repr::$Xn => Some(quote!($xn)),
Self::$Xn => tokens.extend(quote!($xn)),
)*
};
let packed = self.packed.map(|p| {
let lit = LitInt::new(&p.to_string(), Span::call_site());
quote!(packed(#lit))
});
let comma = if packed.is_some() && repr.is_some() {
Some(quote!(,))
} else {
None
};
tokens.extend(quote!(
#[repr( #repr #comma #packed )]
));
}
}
}
)}
use mk_repr;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Repr {
Rust,
C,
Transparent,
Integer(IntegerRepr),
CWithDiscriminant(IntegerRepr),
}
impl Repr {
fn is_integer(&self) -> bool {
matches!(self, Self::Integer(..))
}
fn as_integer(&self) -> Option<IntegerRepr> {
if let Self::Integer(v) = self {
Some(*v)
} else {
None
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Representation {
packed: Option<u32>,
align: Option<u32>,
repr: Repr,
}
impl Default for Representation {
fn default() -> Self {
Self { packed: None, align: None, repr: Repr::Rust }
}
}
impl Parse for Representation {
fn parse(input: ParseStream<'_>) -> Result<Representation> {
let mut ret = Representation::default();
while !input.is_empty() {
let keyword = input.parse::<Ident>()?;
// preëmptively call `.to_string()` *once* (rather than on `is_ident()`)
let keyword_str = keyword.to_string();
let new_repr = match keyword_str.as_str() {
"C" => Repr::C,
"transparent" => Repr::Transparent,
"packed" => {
ret.packed = Some(if input.peek(token::Paren) {
let contents;
parenthesized!(contents in input);
LitInt::base10_parse::<u32>(&contents.parse()?)?
} else {
1
});
let _: Option<Token![,]> = input.parse()?;
continue;
}
"align" => {
let contents;
parenthesized!(contents in input);
let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
ret.align = Some(
ret
.align
.map_or(new_align, |old_align| cmp::max(old_align, new_align)),
);
let _: Option<Token![,]> = input.parse()?;
continue;
}
ident => {
let primitive = IntegerRepr::try_from(ident)
.map_err(|_| input.error("unrecognized representation hint"))?;
Repr::Integer(primitive)
}
};
ret.repr = match (ret.repr, new_repr) {
(Repr::Rust, new_repr) => {
// This is the first explicit repr.
new_repr
}
(Repr::C, Repr::Integer(integer))
| (Repr::Integer(integer), Repr::C) => {
// Both the C repr and an integer repr have been specified
// -> merge into a C wit discriminant.
Repr::CWithDiscriminant(integer)
}
(_, _) => {
return Err(input.error("duplicate representation hint"));
}
};
let _: Option<Token![,]> = input.parse()?;
}
Ok(ret)
}
}
impl ToTokens for Representation {
fn to_tokens(&self, tokens: &mut TokenStream) {
let mut meta = Punctuated::<_, Token![,]>::new();
match self.repr {
Repr::Rust => {}
Repr::C => meta.push(quote!(C)),
Repr::Transparent => meta.push(quote!(transparent)),
Repr::Integer(primitive) => meta.push(quote!(#primitive)),
Repr::CWithDiscriminant(primitive) => {
meta.push(quote!(C));
meta.push(quote!(#primitive));
}
}
if let Some(packed) = self.packed.as_ref() {
let lit = LitInt::new(&packed.to_string(), Span::call_site());
meta.push(quote!(packed(#lit)));
}
if let Some(align) = self.align.as_ref() {
let lit = LitInt::new(&align.to_string(), Span::call_site());
meta.push(quote!(align(#lit)));
}
tokens.extend(quote!(
#[repr(#meta)]
));
}
}
fn enum_has_fields<'a>(
mut variants: impl Iterator<Item = &'a Variant>,
) -> bool {
variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
}
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
inner: I,
last_value: i64,
@ -791,12 +1114,6 @@ impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
fn next(&mut self) -> Option<Self::Item> {
let variant = self.inner.next()?;
if !variant.fields.is_empty() {
return Some(Err(Error::new_spanned(
&variant.fields,
"Only fieldless enums are supported",
)));
}
if let Some((_, discriminant)) = &variant.discriminant {
let discriminant_value = match parse_int_expr(discriminant) {
@ -822,3 +1139,83 @@ fn parse_int_expr(expr: &Expr) -> Result<i64> {
_ => bail!("Not an integer expression"),
}
}
#[cfg(test)]
mod tests {
use syn::parse_quote;
use super::{get_repr, IntegerRepr, Repr, Representation};
#[test]
fn parse_basic_repr() {
let attr = parse_quote!(#[repr(C)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { repr: Repr::C, ..Default::default() });
let attr = parse_quote!(#[repr(transparent)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation { repr: Repr::Transparent, ..Default::default() }
);
let attr = parse_quote!(#[repr(u8)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation {
repr: Repr::Integer(IntegerRepr::U8),
..Default::default()
}
);
let attr = parse_quote!(#[repr(packed)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });
let attr = parse_quote!(#[repr(packed(1))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });
let attr = parse_quote!(#[repr(packed(2))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { packed: Some(2), ..Default::default() });
let attr = parse_quote!(#[repr(align(2))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { align: Some(2), ..Default::default() });
}
#[test]
fn parse_advanced_repr() {
let attr = parse_quote!(#[repr(align(4), align(2))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { align: Some(4), ..Default::default() });
let attr1 = parse_quote!(#[repr(align(1))]);
let attr2 = parse_quote!(#[repr(align(4))]);
let attr3 = parse_quote!(#[repr(align(2))]);
let repr = get_repr(&[attr1, attr2, attr3]).unwrap();
assert_eq!(repr, Representation { align: Some(4), ..Default::default() });
let attr = parse_quote!(#[repr(C, u8)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation {
repr: Repr::CWithDiscriminant(IntegerRepr::U8),
..Default::default()
}
);
let attr = parse_quote!(#[repr(u8, C)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation {
repr: Repr::CWithDiscriminant(IntegerRepr::U8),
..Default::default()
}
);
}
}

View File

@ -2,7 +2,7 @@
use bytemuck::{
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
TransparentWrapper, Zeroable,
TransparentWrapper, Zeroable, checked::CheckedCastError,
};
use std::marker::{PhantomData, PhantomPinned};
@ -160,6 +160,66 @@ struct AnyBitPatternTest<A: AnyBitPattern, B: AnyBitPattern> {
b: B,
}
#[derive(Clone, Copy, CheckedBitPattern)]
#[repr(C, align(8))]
struct CheckedBitPatternAlignedStruct {
a: u16,
}
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
#[repr(C)]
enum CheckedBitPatternCDefaultDiscriminantEnumWithFields {
A(u64),
B { c: u64 },
}
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
#[repr(C, u8)]
enum CheckedBitPatternCEnumWithFields {
A(u32),
B { c: u32 },
}
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
#[repr(u8)]
enum CheckedBitPatternIntEnumWithFields {
A(u8),
B { c: u32 },
}
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
#[repr(transparent)]
enum CheckedBitPatternTransparentEnumWithFields {
A { b: u32 },
}
// size 24, align 8.
// first byte always the u8 discriminant, then 7 bytes of padding until the payload union since the align of the payload
// is the greatest of the align of all the variants, which is 8 (from CheckedBitPatternCDefaultDiscriminantEnumWithFields)
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
#[repr(C, u8)]
enum CheckedBitPatternEnumNested {
A(CheckedBitPatternCEnumWithFields),
B(CheckedBitPatternCDefaultDiscriminantEnumWithFields),
}
/// ```compile_fail
/// use bytemuck::{Pod, Zeroable};
///
/// #[derive(Pod, Zeroable)]
/// #[repr(transparent)]
/// struct TransparentSingle<T>(T);
///
/// struct NotPod(u32);
///
/// let _: u32 = bytemuck::cast(TransparentSingle(NotPod(0u32)));
/// ```
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
)]
#[repr(transparent)]
struct NewtypeWrapperTest<T>(T);
#[test]
fn fails_cast_contiguous() {
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
@ -246,6 +306,140 @@ fn checkedbitpattern_try_pod_read_unaligned() {
assert!(res.is_err());
}
#[test]
fn checkedbitpattern_aligned_struct() {
let pod = [0u8; 8];
bytemuck::checked::pod_read_unaligned::<CheckedBitPatternAlignedStruct>(&pod);
}
#[test]
fn checkedbitpattern_c_default_discriminant_enum_with_fields() {
let pod = [
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0x55,
0x55, 0x55, 0x55, 0xcc,
];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternCDefaultDiscriminantEnumWithFields,
>(&pod);
assert_eq!(
value,
CheckedBitPatternCDefaultDiscriminantEnumWithFields::A(0xcc555555555555cc)
);
let pod = [
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0x55,
0x55, 0x55, 0x55, 0xcc,
];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternCDefaultDiscriminantEnumWithFields,
>(&pod);
assert_eq!(
value,
CheckedBitPatternCDefaultDiscriminantEnumWithFields::B {
c: 0xcc555555555555cc
}
);
}
#[test]
fn checkedbitpattern_c_enum_with_fields() {
let pod = [0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternCEnumWithFields,
>(&pod);
assert_eq!(value, CheckedBitPatternCEnumWithFields::A(0xcc5555cc));
let pod = [0x01, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternCEnumWithFields,
>(&pod);
assert_eq!(value, CheckedBitPatternCEnumWithFields::B { c: 0xcc5555cc });
}
#[test]
fn checkedbitpattern_int_enum_with_fields() {
let pod = [0x00, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternIntEnumWithFields,
>(&pod);
assert_eq!(value, CheckedBitPatternIntEnumWithFields::A(0x55));
let pod = [0x01, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternIntEnumWithFields,
>(&pod);
assert_eq!(value, CheckedBitPatternIntEnumWithFields::B { c: 0xcc5555cc });
}
#[test]
fn checkedbitpattern_nested_enum_with_fields() {
// total size 24 bytes. first byte always the u8 discriminant.
#[repr(C, align(8))]
struct Align8Bytes([u8; 24]);
// first we'll check variantA, nested variant A
let pod = Align8Bytes([
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding.
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding
]);
let value = bytemuck::checked::from_bytes::<
CheckedBitPatternEnumNested,
>(&pod.0);
assert_eq!(value, &CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A(0xcc5555cc)));
// next we'll check invalid first discriminant fails
let pod = Align8Bytes([
0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding
]);
let result = bytemuck::checked::try_from_bytes::<
CheckedBitPatternEnumNested,
>(&pod.0);
assert_eq!(result, Err(CheckedCastError::InvalidBitPattern));
// next we'll check variant B, nested variant B
let pod = Align8Bytes([
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte order) = variant B
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B
]);
let value = bytemuck::checked::from_bytes::<
CheckedBitPatternEnumNested,
>(&pod.0);
assert_eq!(
value,
&CheckedBitPatternEnumNested::B(CheckedBitPatternCDefaultDiscriminantEnumWithFields::B {
c: 0xcc555555555555cc
})
);
// finally we'll check variant B, nested invalid discriminant
let pod = Align8Bytes([
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 discriminant = variant B, bytes 1-7 padding
0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is invalid
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B
]);
let result = bytemuck::checked::try_from_bytes::<
CheckedBitPatternEnumNested,
>(&pod.0);
assert_eq!(result, Err(CheckedCastError::InvalidBitPattern));
}
#[test]
fn checkedbitpattern_transparent_enum_with_fields() {
let pod = [0xcc, 0x55, 0x55, 0xcc];
let value = bytemuck::checked::pod_read_unaligned::<
CheckedBitPatternTransparentEnumWithFields,
>(&pod);
assert_eq!(
value,
CheckedBitPatternTransparentEnumWithFields::A { b: 0xcc5555cc }
);
}
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
#[repr(C, align(16))]
struct Issue127 {}