mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-23 23:32:24 +00:00
derive(Zeroable) on fieldful enums and repr(C) enums (#257)
* Add support for deriving Zeroable for fieldful enums if: 1. the enum is repr(Int), repr(C), or repr(C, Int), 2. the enum has a variant with discriminant 0, 3. and all fields of the variant with discriminant 0 are Zeroable. * Allow using derive(Zeroable) with explicit bounds. Update documentation and doctests. * doc update * doc update * remove unused * Factor out get_zero_variant helper function. * Use i128 to track disciminants instead of i64. * Add doc-comment for `get_fields` Co-authored-by: Daniel Henry-Mantilla <daniel.henry.mantilla@gmail.com> * Update derive/src/traits.rs Co-authored-by: Daniel Henry-Mantilla <daniel.henry.mantilla@gmail.com> --------- Co-authored-by: Daniel Henry-Mantilla <daniel.henry.mantilla@gmail.com>
This commit is contained in:
parent
bb368799c3
commit
a637e1d983
@ -114,14 +114,26 @@ pub fn derive_anybitpattern(
|
||||
proc_macro::TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// Derive the `Zeroable` trait for a struct
|
||||
/// Derive the `Zeroable` trait for a type.
|
||||
///
|
||||
/// The macro ensures that the struct follows all the the safety requirements
|
||||
/// The macro ensures that the type follows all the the safety requirements
|
||||
/// for the `Zeroable` trait.
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed
|
||||
/// The following constraints need to be satisfied for the macro to succeed on a
|
||||
/// struct:
|
||||
///
|
||||
/// - All fields in the struct must to implement `Zeroable`
|
||||
/// - All fields in the struct must implement `Zeroable`
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed on
|
||||
/// an enum:
|
||||
///
|
||||
/// - The enum has an explicit `#[repr(Int)]`, `#[repr(C)]`, or `#[repr(C,
|
||||
/// Int)]`.
|
||||
/// - The enum has a variant with discriminant 0 (explicitly or implicitly).
|
||||
/// - All fields in the variant with discriminant 0 (if any) must implement
|
||||
/// `Zeroable`
|
||||
///
|
||||
/// The macro always succeeds on unions.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
@ -134,6 +146,23 @@ pub fn derive_anybitpattern(
|
||||
/// b: u16,
|
||||
/// }
|
||||
/// ```
|
||||
/// ```rust
|
||||
/// # use bytemuck_derive::{Zeroable};
|
||||
/// #[derive(Copy, Clone, Zeroable)]
|
||||
/// #[repr(i32)]
|
||||
/// enum Values {
|
||||
/// A = 0,
|
||||
/// B = 1,
|
||||
/// C = 2,
|
||||
/// }
|
||||
/// #[derive(Clone, Zeroable)]
|
||||
/// #[repr(C)]
|
||||
/// enum Implicit {
|
||||
/// A(bool, u8, char),
|
||||
/// B(String),
|
||||
/// C(std::num::NonZeroU8),
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// # Custom bounds
|
||||
///
|
||||
@ -157,6 +186,18 @@ pub fn derive_anybitpattern(
|
||||
///
|
||||
/// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
|
||||
/// ```
|
||||
/// ```rust
|
||||
/// # use bytemuck::{Zeroable};
|
||||
/// #[derive(Copy, Clone, Zeroable)]
|
||||
/// #[repr(u8)]
|
||||
/// #[zeroable(bound = "")]
|
||||
/// enum MyOption<T> {
|
||||
/// None,
|
||||
/// Some(T),
|
||||
/// }
|
||||
///
|
||||
/// assert!(matches!(MyOption::<std::num::NonZeroU8>::zeroed(), MyOption::None));
|
||||
/// ```
|
||||
///
|
||||
/// ```rust,compile_fail
|
||||
/// # use bytemuck::Zeroable;
|
||||
@ -407,7 +448,8 @@ pub fn derive_byte_eq(
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
let crate_name = bytemuck_crate_name(&input);
|
||||
let ident = input.ident;
|
||||
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
|
||||
let (impl_generics, ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
|
||||
proc_macro::TokenStream::from(quote! {
|
||||
impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
|
||||
@ -460,7 +502,8 @@ pub fn derive_byte_hash(
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
let crate_name = bytemuck_crate_name(&input);
|
||||
let ident = input.ident;
|
||||
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
|
||||
let (impl_generics, ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
|
||||
proc_macro::TokenStream::from(quote! {
|
||||
impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause {
|
||||
@ -569,19 +612,18 @@ fn derive_marker_trait_inner<Trait: Derivable>(
|
||||
.flatten()
|
||||
.collect::<Vec<syn::WherePredicate>>();
|
||||
|
||||
let predicates = &mut input.generics.make_where_clause().predicates;
|
||||
|
||||
predicates.extend(explicit_bounds);
|
||||
|
||||
let fields = match &input.data {
|
||||
syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
|
||||
syn::Data::Union(_) => {
|
||||
let fields = match (Trait::perfect_derive_fields(&input), &input.data) {
|
||||
(Some(fields), _) => fields,
|
||||
(None, syn::Data::Struct(syn::DataStruct { fields, .. })) => {
|
||||
fields.clone()
|
||||
}
|
||||
(None, syn::Data::Union(_)) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
trait_,
|
||||
&"perfect derive is not supported for unions",
|
||||
));
|
||||
}
|
||||
syn::Data::Enum(_) => {
|
||||
(None, syn::Data::Enum(_)) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
trait_,
|
||||
&"perfect derive is not supported for enums",
|
||||
@ -589,6 +631,10 @@ fn derive_marker_trait_inner<Trait: Derivable>(
|
||||
}
|
||||
};
|
||||
|
||||
let predicates = &mut input.generics.make_where_clause().predicates;
|
||||
|
||||
predicates.extend(explicit_bounds);
|
||||
|
||||
for field in fields {
|
||||
let ty = field.ty;
|
||||
predicates.push(syn::parse_quote!(
|
||||
|
@ -44,6 +44,14 @@ pub trait Derivable {
|
||||
fn explicit_bounds_attribute_name() -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
|
||||
/// If this trait has a custom meaning for "perfect derive", this function
|
||||
/// should be overridden to return `Some`.
|
||||
///
|
||||
/// The default is "the fields of a struct; unions and enums not supported".
|
||||
fn perfect_derive_fields(_input: &DeriveInput) -> Option<Fields> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Pod;
|
||||
@ -76,8 +84,11 @@ impl Derivable for Pod {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let assert_fields_are_pod =
|
||||
generate_fields_are_trait(input, Self::ident(input, crate_name)?)?;
|
||||
let assert_fields_are_pod = generate_fields_are_trait(
|
||||
input,
|
||||
None,
|
||||
Self::ident(input, crate_name)?,
|
||||
)?;
|
||||
|
||||
Ok(quote!(
|
||||
#assert_no_padding
|
||||
@ -118,7 +129,7 @@ impl Derivable for AnyBitPattern {
|
||||
match &input.data {
|
||||
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
|
||||
Data::Struct(_) => {
|
||||
generate_fields_are_trait(input, Self::ident(input, crate_name)?)
|
||||
generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
|
||||
}
|
||||
Data::Enum(_) => {
|
||||
bail!("Deriving AnyBitPattern is not supported for enums")
|
||||
@ -129,6 +140,21 @@ impl Derivable for AnyBitPattern {
|
||||
|
||||
pub struct Zeroable;
|
||||
|
||||
/// Helper function to get the variant with discriminant zero (implicit or
|
||||
/// explicit).
|
||||
fn get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>> {
|
||||
let iter = VariantDiscriminantIterator::new(enum_.variants.iter());
|
||||
let mut zero_variant = None;
|
||||
for res in iter {
|
||||
let (discriminant, variant) = res?;
|
||||
if discriminant == 0 {
|
||||
zero_variant = Some(variant);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(zero_variant)
|
||||
}
|
||||
|
||||
impl Derivable for Zeroable {
|
||||
fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
|
||||
Ok(syn::parse_quote!(#crate_name::Zeroable))
|
||||
@ -138,27 +164,16 @@ impl Derivable for Zeroable {
|
||||
let repr = get_repr(attributes)?;
|
||||
match ty {
|
||||
Data::Struct(_) => Ok(()),
|
||||
Data::Enum(DataEnum { variants, .. }) => {
|
||||
if !repr.repr.is_integer() {
|
||||
bail!("Zeroable requires the enum to be an explicit #[repr(Int)]")
|
||||
Data::Enum(_) => {
|
||||
if !matches!(
|
||||
repr.repr,
|
||||
Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_)
|
||||
) {
|
||||
bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]")
|
||||
}
|
||||
|
||||
if variants.iter().any(|variant| !variant.fields.is_empty()) {
|
||||
bail!("Only fieldless enums are supported for Zeroable")
|
||||
}
|
||||
|
||||
let iter = VariantDiscriminantIterator::new(variants.iter());
|
||||
let mut has_zero_variant = false;
|
||||
for res in iter {
|
||||
let discriminant = res?;
|
||||
if discriminant == 0 {
|
||||
has_zero_variant = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !has_zero_variant {
|
||||
bail!("No variant's discriminant is 0")
|
||||
}
|
||||
// We ensure there is a zero variant in `asserts`, since it is needed
|
||||
// there anyway.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -172,15 +187,40 @@ impl Derivable for Zeroable {
|
||||
match &input.data {
|
||||
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
|
||||
Data::Struct(_) => {
|
||||
generate_fields_are_trait(input, Self::ident(input, crate_name)?)
|
||||
generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
|
||||
}
|
||||
Data::Enum(enum_) => {
|
||||
let zero_variant = get_zero_variant(enum_)?;
|
||||
|
||||
if zero_variant.is_none() {
|
||||
bail!("No variant's discriminant is 0")
|
||||
};
|
||||
|
||||
generate_fields_are_trait(
|
||||
input,
|
||||
zero_variant,
|
||||
Self::ident(input, crate_name)?,
|
||||
)
|
||||
}
|
||||
Data::Enum(_) => Ok(quote!()),
|
||||
}
|
||||
}
|
||||
|
||||
fn explicit_bounds_attribute_name() -> Option<&'static str> {
|
||||
Some("zeroable")
|
||||
}
|
||||
|
||||
fn perfect_derive_fields(input: &DeriveInput) -> Option<Fields> {
|
||||
match &input.data {
|
||||
Data::Struct(struct_) => Some(struct_.fields.clone()),
|
||||
Data::Enum(enum_) => {
|
||||
// We handle `Err` returns from `get_zero_variant` in `asserts`, so it's
|
||||
// fine to just ignore them here and return `None`.
|
||||
// Otherwise, we clone the `fields` of the zero variant (if any).
|
||||
Some(get_zero_variant(enum_).ok()??.fields.clone())
|
||||
}
|
||||
Data::Union(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoUninit;
|
||||
@ -216,8 +256,11 @@ impl Derivable for NoUninit {
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct { .. }) => {
|
||||
let assert_no_padding = generate_assert_no_padding(&input)?;
|
||||
let assert_fields_are_no_padding =
|
||||
generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?;
|
||||
let assert_fields_are_no_padding = generate_fields_are_trait(
|
||||
&input,
|
||||
None,
|
||||
Self::ident(input, crate_name)?,
|
||||
)?;
|
||||
|
||||
Ok(quote!(
|
||||
#assert_no_padding
|
||||
@ -282,13 +325,16 @@ impl Derivable for CheckedBitPattern {
|
||||
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct { .. }) => {
|
||||
let assert_fields_are_maybe_pod =
|
||||
generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?;
|
||||
let assert_fields_are_maybe_pod = generate_fields_are_trait(
|
||||
&input,
|
||||
None,
|
||||
Self::ident(input, crate_name)?,
|
||||
)?;
|
||||
|
||||
Ok(assert_fields_are_maybe_pod)
|
||||
}
|
||||
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
|
||||
* OK by NoUninit */
|
||||
// nothing needed, already guaranteed OK by NoUninit.
|
||||
Data::Enum(_) => Ok(quote!()),
|
||||
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
|
||||
}
|
||||
}
|
||||
@ -439,16 +485,16 @@ impl Derivable for Contiguous {
|
||||
));
|
||||
}
|
||||
|
||||
let mut variants_with_discriminator =
|
||||
let mut variants_with_discriminant =
|
||||
VariantDiscriminantIterator::new(variants);
|
||||
|
||||
let (min, max, count) = variants_with_discriminator.try_fold(
|
||||
(i64::max_value(), i64::min_value(), 0),
|
||||
let (min, max, count) = variants_with_discriminant.try_fold(
|
||||
(i128::MAX, i128::MIN, 0),
|
||||
|(min, max, count), res| {
|
||||
let discriminator = res?;
|
||||
let (discriminant, _variant) = res?;
|
||||
Ok::<_, Error>((
|
||||
i64::min(min, discriminator),
|
||||
i64::max(max, discriminator),
|
||||
i128::min(min, discriminant),
|
||||
i128::max(max, discriminant),
|
||||
count + 1,
|
||||
))
|
||||
},
|
||||
@ -505,11 +551,21 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_fields(input: &DeriveInput) -> Result<Fields> {
|
||||
/// Extract the `Fields` off a `DeriveInput`, or, in the `enum` case, off
|
||||
/// those of the `enum_variant`, when provided (e.g., for `Zeroable`).
|
||||
///
|
||||
/// We purposely allow not providing an `enum_variant` for cases where
|
||||
/// the caller wants to reject supporting `enum`s (e.g., `NoPadding`).
|
||||
fn get_fields(
|
||||
input: &DeriveInput, enum_variant: Option<&Variant>,
|
||||
) -> Result<Fields> {
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
|
||||
Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
|
||||
Data::Enum(_) => bail!("deriving this trait is not supported for enums"),
|
||||
Data::Enum(_) => match enum_variant {
|
||||
Some(variant) => Ok(variant.fields.clone()),
|
||||
None => bail!("deriving this trait is not supported for enums"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -596,12 +652,12 @@ fn generate_checked_bit_pattern_enum_without_fields(
|
||||
VariantDiscriminantIterator::new(variants.iter());
|
||||
|
||||
let (min, max, count) = variants_with_discriminant.try_fold(
|
||||
(i64::max_value(), i64::min_value(), 0),
|
||||
(i128::MAX, i128::MIN, 0),
|
||||
|(min, max, count), res| {
|
||||
let discriminant = res?;
|
||||
let (discriminant, _variant) = res?;
|
||||
Ok::<_, Error>((
|
||||
i64::min(min, discriminant),
|
||||
i64::max(max, discriminant),
|
||||
i128::min(min, discriminant),
|
||||
i128::max(max, discriminant),
|
||||
count + 1,
|
||||
))
|
||||
},
|
||||
@ -617,16 +673,17 @@ fn generate_checked_bit_pattern_enum_without_fields(
|
||||
quote!(*bits >= #min_lit && *bits <= #max_lit)
|
||||
} else {
|
||||
// not contiguous range, check for each
|
||||
let variant_lits = VariantDiscriminantIterator::new(variants.iter())
|
||||
.map(|res| {
|
||||
let variant = res?;
|
||||
Ok(LitInt::new(&format!("{}", variant), span))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let variant_discriminant_lits =
|
||||
VariantDiscriminantIterator::new(variants.iter())
|
||||
.map(|res| {
|
||||
let (discriminant, _variant) = res?;
|
||||
Ok(LitInt::new(&format!("{}", discriminant), span))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
// count is at least 1
|
||||
let first = &variant_lits[0];
|
||||
let rest = &variant_lits[1..];
|
||||
let first = &variant_discriminant_lits[0];
|
||||
let rest = &variant_discriminant_lits[1..];
|
||||
|
||||
quote!(matches!(*bits, #first #(| #rest )*))
|
||||
};
|
||||
@ -720,7 +777,7 @@ fn generate_checked_bit_pattern_enum_with_fields(
|
||||
.zip(VariantDiscriminantIterator::new(variants.iter()))
|
||||
.zip(variants.iter())
|
||||
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
|
||||
let discriminant = discriminant?;
|
||||
let (discriminant, _variant) = discriminant?;
|
||||
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
|
||||
let ident = &v.ident;
|
||||
Ok(quote! {
|
||||
@ -850,7 +907,7 @@ fn generate_checked_bit_pattern_enum_with_fields(
|
||||
.zip(VariantDiscriminantIterator::new(variants.iter()))
|
||||
.zip(variants.iter())
|
||||
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
|
||||
let discriminant = discriminant?;
|
||||
let (discriminant, _variant) = discriminant?;
|
||||
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
|
||||
let ident = &v.ident;
|
||||
Ok(quote! {
|
||||
@ -906,7 +963,8 @@ fn generate_checked_bit_pattern_enum_with_fields(
|
||||
fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
|
||||
let struct_type = &input.ident;
|
||||
let span = input.ident.span();
|
||||
let fields = get_fields(input)?;
|
||||
let enum_variant = None; // `no padding` check is not supported for `enum`s yet.
|
||||
let fields = get_fields(input, enum_variant)?;
|
||||
|
||||
let mut field_types = get_field_types(&fields);
|
||||
let size_sum = if let Some(first) = field_types.next() {
|
||||
@ -928,11 +986,11 @@ fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
|
||||
|
||||
/// Check that all fields implement a given trait
|
||||
fn generate_fields_are_trait(
|
||||
input: &DeriveInput, trait_: syn::Path,
|
||||
input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path,
|
||||
) -> Result<TokenStream> {
|
||||
let (impl_generics, _ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
let fields = get_fields(input)?;
|
||||
let fields = get_fields(input, enum_variant)?;
|
||||
let span = input.span();
|
||||
let field_types = get_field_types(&fields);
|
||||
Ok(quote_spanned! {span => #(const _: fn() = || {
|
||||
@ -1186,7 +1244,7 @@ fn enum_has_fields<'a>(
|
||||
|
||||
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
|
||||
inner: I,
|
||||
last_value: i64,
|
||||
last_value: i128,
|
||||
}
|
||||
|
||||
impl<'a, I: Iterator<Item = &'a Variant> + 'a>
|
||||
@ -1200,7 +1258,7 @@ impl<'a, I: Iterator<Item = &'a Variant> + 'a>
|
||||
impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
|
||||
for VariantDiscriminantIterator<'a, I>
|
||||
{
|
||||
type Item = Result<i64>;
|
||||
type Item = Result<(i128, &'a Variant)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let variant = self.inner.next()?;
|
||||
@ -1212,14 +1270,38 @@ impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
|
||||
};
|
||||
self.last_value = discriminant_value;
|
||||
} else {
|
||||
self.last_value += 1;
|
||||
// If this wraps, then either:
|
||||
// 1. the enum is using repr(u128), so wrapping is correct
|
||||
// 2. the enum is using repr(i<=128 or u<128), so the compiler will
|
||||
// already emit a "wrapping discriminant" E0370 error.
|
||||
self.last_value = self.last_value.wrapping_add(1);
|
||||
// Static assert that there is no integer repr > 128 bits. If that
|
||||
// changes, the above comment is inaccurate and needs to be updated!
|
||||
// FIXME(zachs18): maybe should also do something to ensure `isize::BITS
|
||||
// <= 128`?
|
||||
if let Some(repr) = None::<IntegerRepr> {
|
||||
match repr {
|
||||
IntegerRepr::U8
|
||||
| IntegerRepr::I8
|
||||
| IntegerRepr::U16
|
||||
| IntegerRepr::I16
|
||||
| IntegerRepr::U32
|
||||
| IntegerRepr::I32
|
||||
| IntegerRepr::U64
|
||||
| IntegerRepr::I64
|
||||
| IntegerRepr::I128
|
||||
| IntegerRepr::U128
|
||||
| IntegerRepr::Usize
|
||||
| IntegerRepr::Isize => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(Ok(self.last_value))
|
||||
Some(Ok((self.last_value, variant)))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_int_expr(expr: &Expr) -> Result<i64> {
|
||||
fn parse_int_expr(expr: &Expr) -> Result<i128> {
|
||||
match expr {
|
||||
Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
|
||||
parse_int_expr(expr).map(|int| -int)
|
||||
|
@ -1,8 +1,8 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytemuck::{
|
||||
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
|
||||
TransparentWrapper, Zeroable, checked::CheckedCastError,
|
||||
checked::CheckedCastError, AnyBitPattern, CheckedBitPattern, Contiguous,
|
||||
NoUninit, Pod, TransparentWrapper, Zeroable,
|
||||
};
|
||||
use std::marker::{PhantomData, PhantomPinned};
|
||||
|
||||
@ -58,6 +58,45 @@ enum ZeroEnum {
|
||||
C = 2,
|
||||
}
|
||||
|
||||
#[derive(Zeroable)]
|
||||
#[repr(u8)]
|
||||
enum BasicFieldfulZeroEnum {
|
||||
A(u8) = 0,
|
||||
B = 1,
|
||||
C(String) = 2,
|
||||
}
|
||||
|
||||
#[derive(Zeroable)]
|
||||
#[repr(C)]
|
||||
enum ReprCFieldfulZeroEnum {
|
||||
A(u8),
|
||||
B(Box<[u8]>),
|
||||
C,
|
||||
}
|
||||
|
||||
#[derive(Zeroable)]
|
||||
#[repr(C, i32)]
|
||||
enum ReprCIntFieldfulZeroEnum {
|
||||
B(String) = 1,
|
||||
A(u8, bool, char) = 0,
|
||||
C = 2,
|
||||
}
|
||||
|
||||
#[derive(Zeroable)]
|
||||
#[repr(i32)]
|
||||
enum GenericFieldfulZeroEnum<T> {
|
||||
A(Box<T>) = 1,
|
||||
B(T, T) = 0,
|
||||
}
|
||||
|
||||
#[derive(Zeroable)]
|
||||
#[repr(i32)]
|
||||
#[zeroable(bound = "")]
|
||||
enum GenericCustomBoundFieldfulZeroEnum<T> {
|
||||
A(Option<Box<T>>),
|
||||
B(String),
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
struct TransparentSingle {
|
||||
@ -202,8 +241,10 @@ enum CheckedBitPatternTransparentEnumWithFields {
|
||||
}
|
||||
|
||||
// size 24, align 8.
|
||||
// first byte always the u8 discriminant, then 7 bytes of padding until the payload union since the align of the payload
|
||||
// is the greatest of the align of all the variants, which is 8 (from CheckedBitPatternCDefaultDiscriminantEnumWithFields)
|
||||
// first byte always the u8 discriminant, then 7 bytes of padding until the
|
||||
// payload union since the align of the payload is the greatest of the align of
|
||||
// all the variants, which is 8 (from
|
||||
// CheckedBitPatternCDefaultDiscriminantEnumWithFields)
|
||||
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
|
||||
#[repr(C, u8)]
|
||||
enum CheckedBitPatternEnumNested {
|
||||
@ -388,52 +429,68 @@ fn checkedbitpattern_nested_enum_with_fields() {
|
||||
|
||||
// first we'll check variantA, nested variant A
|
||||
let pod = Align8Bytes([
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding.
|
||||
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding.
|
||||
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55,
|
||||
0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding
|
||||
]);
|
||||
let value = bytemuck::checked::from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
assert_eq!(value, &CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A(0xcc5555cc)));
|
||||
let value =
|
||||
bytemuck::checked::from_bytes::<CheckedBitPatternEnumNested>(&pod.0);
|
||||
assert_eq!(
|
||||
value,
|
||||
&CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A(
|
||||
0xcc5555cc
|
||||
))
|
||||
);
|
||||
|
||||
// next we'll check invalid first discriminant fails
|
||||
let pod = Align8Bytes([
|
||||
0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding
|
||||
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A,
|
||||
0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding
|
||||
0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55,
|
||||
0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding
|
||||
]);
|
||||
let result = bytemuck::checked::try_from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
let result =
|
||||
bytemuck::checked::try_from_bytes::<CheckedBitPatternEnumNested>(&pod.0);
|
||||
assert_eq!(result, Err(CheckedCastError::InvalidBitPattern));
|
||||
|
||||
|
||||
// next we'll check variant B, nested variant B
|
||||
let pod = Align8Bytes([
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte order) = variant B
|
||||
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, /* bytes 8-15 is C int size discriminant of
|
||||
* CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte
|
||||
* order) = variant B */
|
||||
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55,
|
||||
0xcc, // bytes 16-13 is the data contained in nested variant B
|
||||
]);
|
||||
let value = bytemuck::checked::from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
let value =
|
||||
bytemuck::checked::from_bytes::<CheckedBitPatternEnumNested>(&pod.0);
|
||||
assert_eq!(
|
||||
value,
|
||||
&CheckedBitPatternEnumNested::B(CheckedBitPatternCDefaultDiscriminantEnumWithFields::B {
|
||||
c: 0xcc555555555555cc
|
||||
})
|
||||
&CheckedBitPatternEnumNested::B(
|
||||
CheckedBitPatternCDefaultDiscriminantEnumWithFields::B {
|
||||
c: 0xcc555555555555cc
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
// finally we'll check variant B, nested invalid discriminant
|
||||
let pod = Align8Bytes([
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 discriminant = variant B, bytes 1-7 padding
|
||||
0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is invalid
|
||||
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, // 1 discriminant = variant B, bytes 1-7 padding
|
||||
0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, /* bytes 8-15 is C int size discriminant of
|
||||
* CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is
|
||||
* invalid */
|
||||
0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55,
|
||||
0xcc, // bytes 16-13 is the data contained in nested variant B
|
||||
]);
|
||||
let result = bytemuck::checked::try_from_bytes::<
|
||||
CheckedBitPatternEnumNested,
|
||||
>(&pod.0);
|
||||
let result =
|
||||
bytemuck::checked::try_from_bytes::<CheckedBitPatternEnumNested>(&pod.0);
|
||||
assert_eq!(result, Err(CheckedCastError::InvalidBitPattern));
|
||||
}
|
||||
#[test]
|
||||
@ -457,4 +514,3 @@ use bytemuck as reexport_name;
|
||||
#[bytemuck(crate = "reexport_name")]
|
||||
#[repr(C)]
|
||||
struct Issue93 {}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user