Allow repr(transparent) to be used generically in derive(Pod) (#139)

* Enabled transparent generics

* Move trait checks to implementation block

* Replace add_trait_marker impl
This commit is contained in:
John Nunley 2022-11-03 06:53:01 -07:00 committed by GitHub
parent 7b67524a43
commit 518baf9c0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 149 additions and 116 deletions

View File

@ -9,7 +9,8 @@ use quote::quote;
use syn::{parse_macro_input, DeriveInput, Result}; use syn::{parse_macro_input, DeriveInput, Result};
use crate::traits::{ use crate::traits::{
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, NoUninit, Pod,
TransparentWrapper, Zeroable,
}; };
/// Derive the `Pod` trait for a struct /// Derive the `Pod` trait for a struct
@ -56,8 +57,9 @@ pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
pub fn derive_anybitpattern( pub fn derive_anybitpattern(
input: proc_macro::TokenStream, input: proc_macro::TokenStream,
) -> proc_macro::TokenStream { ) -> proc_macro::TokenStream {
let expanded = let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
derive_marker_trait::<AnyBitPattern>(parse_macro_input!(input as DeriveInput)); input as DeriveInput
));
proc_macro::TokenStream::from(expanded) proc_macro::TokenStream::from(expanded)
} }
@ -99,8 +101,8 @@ pub fn derive_zeroable(
/// for the `NoUninit` trait. /// for the `NoUninit` trait.
/// ///
/// The following constraints need to be satisfied for the macro to succeed /// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait bounds, /// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
/// i.e. the type must be `Sized + Copy + 'static`): /// bounds, i.e. the type must be `Sized + Copy + 'static`):
/// ///
/// If applied to a struct: /// If applied to a struct:
/// - All fields in the struct must implement `NoUninit` /// - All fields in the struct must implement `NoUninit`
@ -129,9 +131,9 @@ pub fn derive_no_uninit(
/// definition and `is_valid_bit_pattern` method for the type automatically. /// definition and `is_valid_bit_pattern` method for the type automatically.
/// ///
/// The following constraints need to be satisfied for the macro to succeed /// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern` subtrait bounds, /// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
/// i.e. are guaranteed by the requirements of the `NoUninit` trait which `CheckedBitPattern` /// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
/// is a subtrait of): /// trait which `CheckedBitPattern` is a subtrait of):
/// ///
/// If applied to a struct: /// If applied to a struct:
/// - All fields must implement `CheckedBitPattern` /// - All fields must implement `CheckedBitPattern`
@ -142,8 +144,9 @@ pub fn derive_no_uninit(
pub fn derive_maybe_pod( pub fn derive_maybe_pod(
input: proc_macro::TokenStream, input: proc_macro::TokenStream,
) -> proc_macro::TokenStream { ) -> proc_macro::TokenStream {
let expanded = let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(input as DeriveInput)); input as DeriveInput
));
proc_macro::TokenStream::from(expanded) proc_macro::TokenStream::from(expanded)
} }
@ -228,17 +231,19 @@ fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
} }
fn derive_marker_trait_inner<Trait: Derivable>( fn derive_marker_trait_inner<Trait: Derivable>(
input: DeriveInput, mut input: DeriveInput,
) -> Result<TokenStream> { ) -> Result<TokenStream> {
// Enforce Pod on all generic fields.
let trait_ = Trait::ident(&input)?;
add_trait_marker(&mut input.generics, &trait_);
let name = &input.ident; let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl(); input.generics.split_for_impl();
let trait_ = Trait::ident();
Trait::check_attributes(&input.data, &input.attrs)?; Trait::check_attributes(&input.data, &input.attrs)?;
let asserts = Trait::asserts(&input)?; let asserts = Trait::asserts(&input)?;
let trait_params = Trait::generic_params(&input)?;
let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?; let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?;
let implies_trait = if let Some(implies_trait) = Trait::implies_trait() { let implies_trait = if let Some(implies_trait) = Trait::implies_trait() {
@ -252,10 +257,23 @@ fn derive_marker_trait_inner<Trait: Derivable>(
#trait_impl_extras #trait_impl_extras
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause { unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
#trait_impl #trait_impl
} }
#implies_trait #implies_trait
}) })
} }
/// Add a trait marker to the generics if it is not already present
fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
// Get each generic type parameter.
let type_params = generics
.type_params()
.map(|param| &param.ident)
.map(|param| syn::parse_quote!(
#param: #trait_name
)).collect::<Vec<syn::WherePredicate>>();
generics.make_where_clause().predicates.extend(type_params);
}

View File

@ -1,42 +1,35 @@
#![allow(unused_imports)] #![allow(unused_imports)]
use proc_macro2::{Ident, Span, TokenStream, TokenTree}; use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens}; use quote::{quote, quote_spanned, ToTokens};
use syn::{*, use syn::{
parse::{Parse, Parser, ParseStream}, parse::{Parse, ParseStream, Parser},
punctuated::Punctuated, punctuated::Punctuated,
spanned::Spanned, spanned::Spanned,
Result, Result, *,
}; };
macro_rules! bail { macro_rules! bail {
($msg:expr $(,)?) => ( ($msg:expr $(,)?) => {
return Err(Error::new(Span::call_site(), &$msg[..])) return Err(Error::new(Span::call_site(), &$msg[..]))
); };
( $msg:expr => $span_to_blame:expr $(,)? ) => ( ( $msg:expr => $span_to_blame:expr $(,)? ) => {
return Err(Error::new_spanned(&$span_to_blame, $msg)) return Err(Error::new_spanned(&$span_to_blame, $msg))
); };
} }
pub trait Derivable { pub trait Derivable {
fn ident() -> TokenStream; fn ident(input: &DeriveInput) -> Result<syn::Path>;
fn implies_trait() -> Option<TokenStream> { fn implies_trait() -> Option<TokenStream> {
None None
} }
fn generic_params(_input: &DeriveInput) -> Result<TokenStream> {
Ok(quote!())
}
fn asserts(_input: &DeriveInput) -> Result<TokenStream> { fn asserts(_input: &DeriveInput) -> Result<TokenStream> {
Ok(quote!()) Ok(quote!())
} }
fn check_attributes( fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
_ty: &Data, _attributes: &[Attribute],
) -> Result<()> {
Ok(()) Ok(())
} }
fn trait_impl( fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
_input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!())) Ok((quote!(), quote!()))
} }
} }
@ -44,14 +37,15 @@ pub trait Derivable {
pub struct Pod; pub struct Pod;
impl Derivable for Pod { impl Derivable for Pod {
fn ident() -> TokenStream { fn ident(_: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::Pod) Ok(syn::parse_quote!(::bytemuck::Pod))
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream> { fn asserts(input: &DeriveInput) -> Result<TokenStream> {
let repr = get_repr(&input.attrs)?; let repr = get_repr(&input.attrs)?;
let completly_packed = repr.packed == Some(1); let completly_packed =
repr.packed == Some(1) || repr.repr == Repr::Transparent;
if !completly_packed && !input.generics.params.is_empty() { if !completly_packed && !input.generics.params.is_empty() {
bail!("\ bail!("\
@ -69,7 +63,7 @@ impl Derivable for Pod {
None None
}; };
let assert_fields_are_pod = let assert_fields_are_pod =
generate_fields_are_trait(input, Self::ident())?; generate_fields_are_trait(input, Self::ident(input)?)?;
Ok(quote!( Ok(quote!(
#assert_no_padding #assert_no_padding
@ -81,9 +75,7 @@ impl Derivable for Pod {
} }
} }
fn check_attributes( fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
_ty: &Data, attributes: &[Attribute],
) -> Result<()> {
let repr = get_repr(attributes)?; let repr = get_repr(attributes)?;
match repr.repr { match repr.repr {
Repr::C => Ok(()), Repr::C => Ok(()),
@ -98,8 +90,8 @@ impl Derivable for Pod {
pub struct AnyBitPattern; pub struct AnyBitPattern;
impl Derivable for AnyBitPattern { impl Derivable for AnyBitPattern {
fn ident() -> TokenStream { fn ident(_: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::AnyBitPattern) Ok(syn::parse_quote!(::bytemuck::AnyBitPattern))
} }
fn implies_trait() -> Option<TokenStream> { fn implies_trait() -> Option<TokenStream> {
@ -109,8 +101,10 @@ impl Derivable for AnyBitPattern {
fn asserts(input: &DeriveInput) -> Result<TokenStream> { fn asserts(input: &DeriveInput) -> Result<TokenStream> {
match &input.data { match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern` Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()), Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?),
Data::Enum(_) => bail!("Deriving AnyBitPattern is not supported for enums"), Data::Enum(_) => {
bail!("Deriving AnyBitPattern is not supported for enums")
}
} }
} }
} }
@ -118,14 +112,14 @@ impl Derivable for AnyBitPattern {
pub struct Zeroable; pub struct Zeroable;
impl Derivable for Zeroable { impl Derivable for Zeroable {
fn ident() -> TokenStream { fn ident(_: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::Zeroable) Ok(syn::parse_quote!(::bytemuck::Zeroable))
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream> { fn asserts(input: &DeriveInput) -> Result<TokenStream> {
match &input.data { match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable` Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()), Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?),
Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"), Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"),
} }
} }
@ -134,13 +128,11 @@ impl Derivable for Zeroable {
pub struct NoUninit; pub struct NoUninit;
impl Derivable for NoUninit { impl Derivable for NoUninit {
fn ident() -> TokenStream { fn ident(_: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::NoUninit) Ok(syn::parse_quote!(::bytemuck::NoUninit))
} }
fn check_attributes( fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
ty: &Data, attributes: &[Attribute],
) -> Result<()> {
let repr = get_repr(attributes)?; let repr = get_repr(attributes)?;
match ty { match ty {
Data::Struct(_) => match repr.repr { Data::Struct(_) => match repr.repr {
@ -165,7 +157,7 @@ impl Derivable for NoUninit {
Data::Struct(DataStruct { .. }) => { Data::Struct(DataStruct { .. }) => {
let assert_no_padding = generate_assert_no_padding(&input)?; let assert_no_padding = generate_assert_no_padding(&input)?;
let assert_fields_are_no_padding = let assert_fields_are_no_padding =
generate_fields_are_trait(&input, Self::ident())?; generate_fields_are_trait(&input, Self::ident(input)?)?;
Ok(quote!( Ok(quote!(
#assert_no_padding #assert_no_padding
@ -179,13 +171,11 @@ impl Derivable for NoUninit {
Ok(quote!()) Ok(quote!())
} }
} }
Data::Union(_) => bail!("NoUninit cannot be derived for unions"), // shouldn't be possible since we already error in attribute check for this case Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
} }
} }
fn trait_impl( fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
_input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!())) Ok((quote!(), quote!()))
} }
} }
@ -193,13 +183,11 @@ impl Derivable for NoUninit {
pub struct CheckedBitPattern; pub struct CheckedBitPattern;
impl Derivable for CheckedBitPattern { impl Derivable for CheckedBitPattern {
fn ident() -> TokenStream { fn ident(_: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::CheckedBitPattern) Ok(syn::parse_quote!(::bytemuck::CheckedBitPattern))
} }
fn check_attributes( fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
ty: &Data, attributes: &[Attribute],
) -> Result<()> {
let repr = get_repr(attributes)?; let repr = get_repr(attributes)?;
match ty { match ty {
Data::Struct(_) => match repr.repr { Data::Struct(_) => match repr.repr {
@ -223,24 +211,23 @@ impl Derivable for CheckedBitPattern {
match &input.data { match &input.data {
Data::Struct(DataStruct { .. }) => { Data::Struct(DataStruct { .. }) => {
let assert_fields_are_maybe_pod = let assert_fields_are_maybe_pod =
generate_fields_are_trait(&input, Self::ident())?; generate_fields_are_trait(&input, Self::ident(input)?)?;
Ok(assert_fields_are_maybe_pod) Ok(assert_fields_are_maybe_pod)
} }
Data::Enum(_) => Ok(quote!()), // nothing needed, already guaranteed OK by NoUninit Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case * OK by NoUninit */
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
} }
} }
fn trait_impl( fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
match &input.data { match &input.data {
Data::Struct(DataStruct { fields, .. }) => { Data::Struct(DataStruct { fields, .. }) => {
generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs) generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs)
}, }
Data::Enum(_) => generate_checked_bit_pattern_enum(input), Data::Enum(_) => generate_checked_bit_pattern_enum(input),
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
} }
} }
} }
@ -266,20 +253,20 @@ impl TransparentWrapper {
} }
impl Derivable for TransparentWrapper { impl Derivable for TransparentWrapper {
fn ident() -> TokenStream { fn ident(input: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::TransparentWrapper)
}
fn generic_params(input: &DeriveInput) -> Result<TokenStream> {
let fields = get_struct_fields(input)?; let fields = get_struct_fields(input)?;
match Self::get_wrapper_type(&input.attrs, &fields) { let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
| Some(ty) => Ok(quote!(<#ty>)), Some(ty) => ty,
| None => bail!("\ None => bail!(
"\
when deriving TransparentWrapper for a struct with more than one field \ when deriving TransparentWrapper for a struct with more than one field \
you need to specify the transparent field using #[transparent(T)]\ you need to specify the transparent field using #[transparent(T)]\
"), "
} ),
};
Ok(syn::parse_quote!(::bytemuck::TransparentWrapper<#ty>))
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream> { fn asserts(input: &DeriveInput) -> Result<TokenStream> {
@ -301,15 +288,15 @@ impl Derivable for TransparentWrapper {
} }
} }
fn check_attributes( fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
_ty: &Data, attributes: &[Attribute],
) -> Result<()> {
let repr = get_repr(attributes)?; let repr = get_repr(attributes)?;
match repr.repr { match repr.repr {
Repr::Transparent => Ok(()), Repr::Transparent => Ok(()),
_ => { _ => {
bail!("TransparentWrapper requires the struct to be #[repr(transparent)]") bail!(
"TransparentWrapper requires the struct to be #[repr(transparent)]"
)
} }
} }
} }
@ -318,13 +305,11 @@ impl Derivable for TransparentWrapper {
pub struct Contiguous; pub struct Contiguous;
impl Derivable for Contiguous { impl Derivable for Contiguous {
fn ident() -> TokenStream { fn ident(_: &DeriveInput) -> Result<syn::Path> {
quote!(::bytemuck::Contiguous) Ok(syn::parse_quote!(::bytemuck::Contiguous))
} }
fn trait_impl( fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
input: &DeriveInput,
) -> Result<(TokenStream, TokenStream)> {
let repr = get_repr(&input.attrs)?; let repr = get_repr(&input.attrs)?;
let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() { let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() {
@ -422,7 +407,8 @@ fn generate_checked_bit_pattern_struct(
let field_name = &field_names[..]; let field_name = &field_names[..];
let field_ty = &field_tys[..]; let field_ty = &field_tys[..];
let derive_dbg = quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); let derive_dbg =
quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
Ok(( Ok((
quote! { quote! {
@ -456,7 +442,11 @@ fn generate_checked_bit_pattern_enum(
(i64::max_value(), i64::min_value(), 0), (i64::max_value(), i64::min_value(), 0),
|(min, max, count), res| { |(min, max, count), res| {
let discriminant = res?; let discriminant = res?;
Ok::<_, Error>((i64::min(min, discriminant), i64::max(max, discriminant), count + 1)) Ok::<_, Error>((
i64::min(min, discriminant),
i64::max(max, discriminant),
count + 1,
))
}, },
)?; )?;
@ -503,9 +493,7 @@ fn generate_checked_bit_pattern_enum(
/// Check that a struct has no padding by asserting that the size of the struct /// Check that a struct has no padding by asserting that the size of the struct
/// is equal to the sum of the size of it's fields /// is equal to the sum of the size of it's fields
fn generate_assert_no_padding( fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
input: &DeriveInput,
) -> Result<TokenStream> {
let struct_type = &input.ident; let struct_type = &input.ident;
let span = input.ident.span(); let span = input.ident.span();
let fields = get_fields(input)?; let fields = get_fields(input)?;
@ -529,7 +517,7 @@ fn generate_assert_no_padding(
/// Check that all fields implement a given trait /// Check that all fields implement a given trait
fn generate_fields_are_trait( fn generate_fields_are_trait(
input: &DeriveInput, trait_: TokenStream, input: &DeriveInput, trait_: syn::Path,
) -> Result<TokenStream> { ) -> Result<TokenStream> {
let (impl_generics, _ty_generics, where_clause) = let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl(); input.generics.split_for_impl();
@ -574,28 +562,30 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
fn get_repr(attributes: &[Attribute]) -> Result<Representation> { fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
attributes attributes
.iter() .iter()
.filter_map(|attr| if attr.path.is_ident("repr") { .filter_map(|attr| {
Some(attr.parse_args::<Representation>()) if attr.path.is_ident("repr") {
} else { Some(attr.parse_args::<Representation>())
None } else {
None
}
}) })
.try_fold(Representation::default(), |a, b| { .try_fold(Representation::default(), |a, b| {
let b = b?; let b = b?;
Ok(Representation { Ok(Representation {
repr: match (a.repr, b.repr) { repr: match (a.repr, b.repr) {
| (a, Repr::Rust) => a, (a, Repr::Rust) => a,
| (Repr::Rust, b) => b, (Repr::Rust, b) => b,
| _ => bail!("conflicting representation hints"), _ => bail!("conflicting representation hints"),
}, },
packed: match (a.packed, b.packed) { packed: match (a.packed, b.packed) {
| (a, None) => a, (a, None) => a,
| (None, b) => b, (None, b) => b,
| _ => bail!("conflicting representation hints"), _ => bail!("conflicting representation hints"),
}, },
align: match (a.align, b.align) { align: match (a.align, b.align) {
| (a, None) => a, (a, None) => a,
| (None, b) => b, (None, b) => b,
| _ => bail!("conflicting representation hints"), _ => bail!("conflicting representation hints"),
}, },
}) })
}) })
@ -719,7 +709,8 @@ macro_rules! mk_repr {(
)); ));
} }
} }
)} use mk_repr; )}
use mk_repr;
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> { struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
inner: I, inner: I,
@ -767,9 +758,7 @@ fn parse_int_expr(expr: &Expr) -> Result<i64> {
Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => { Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
parse_int_expr(expr).map(|int| -int) parse_int_expr(expr).map(|int| -int)
} }
Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => { Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
int.base10_parse()
}
_ => bail!("Not an integer expression"), _ => bail!("Not an integer expression"),
} }
} }

View File

@ -1,7 +1,8 @@
#![allow(dead_code)] #![allow(dead_code)]
use bytemuck::{ use bytemuck::{
AnyBitPattern, Contiguous, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable, AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
TransparentWrapper, Zeroable,
}; };
use std::marker::PhantomData; use std::marker::PhantomData;
@ -138,9 +139,26 @@ struct CheckedBitPatternStruct {
#[repr(C)] #[repr(C)]
struct AnyBitPatternTest { struct AnyBitPatternTest {
a: u16, a: u16,
b: u16 b: u16,
} }
/// ```compile_fail
/// use bytemuck::{Pod, Zeroable};
///
/// #[derive(Pod, Zeroable)]
/// #[repr(transparent)]
/// struct TransparentSingle<T>(T);
///
/// struct NotPod(u32);
///
/// let _: u32 = bytemuck::cast(TransparentSingle(NotPod(0u32)));
/// ```
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
)]
#[repr(transparent)]
struct NewtypeWrapperTest<T>(T);
#[test] #[test]
fn fails_cast_contiguous() { fn fails_cast_contiguous() {
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5); let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
@ -149,7 +167,8 @@ fn fails_cast_contiguous() {
#[test] #[test]
fn passes_cast_contiguous() { fn passes_cast_contiguous() {
let res = bytemuck::checked::from_bytes::<CheckedBitPatternEnumWithValues>(&[2u8]); let res =
bytemuck::checked::from_bytes::<CheckedBitPatternEnumWithValues>(&[2u8]);
assert_eq!(*res, CheckedBitPatternEnumWithValues::C); assert_eq!(*res, CheckedBitPatternEnumWithValues::C);
} }
@ -162,7 +181,9 @@ fn fails_cast_noncontiguous() {
#[test] #[test]
fn passes_cast_noncontiguous() { fn passes_cast_noncontiguous() {
let res = let res =
bytemuck::checked::from_bytes::<CheckedBitPatternEnumNonContiguous>(&[56u8]); bytemuck::checked::from_bytes::<CheckedBitPatternEnumNonContiguous>(&[
56u8,
]);
assert_eq!(*res, CheckedBitPatternEnumNonContiguous::E); assert_eq!(*res, CheckedBitPatternEnumNonContiguous::E);
} }
@ -177,7 +198,10 @@ fn fails_cast_struct() {
fn passes_cast_struct() { fn passes_cast_struct() {
let pod = [0u8, 8u8]; let pod = [0u8, 8u8];
let res = bytemuck::checked::from_bytes::<CheckedBitPatternStruct>(&pod); let res = bytemuck::checked::from_bytes::<CheckedBitPatternStruct>(&pod);
assert_eq!(*res, CheckedBitPatternStruct { a: 0, b: CheckedBitPatternEnumNonContiguous::B }); assert_eq!(
*res,
CheckedBitPatternStruct { a: 0, b: CheckedBitPatternEnumNonContiguous::B }
);
} }
#[test] #[test]

View File

@ -23,7 +23,9 @@ possibility code branch.
#[cfg(not(target_arch = "spirv"))] #[cfg(not(target_arch = "spirv"))]
#[cold] #[cold]
#[inline(never)] #[inline(never)]
pub(crate) fn something_went_wrong<D: core::fmt::Display>(_src: &str, _err: D) -> ! { pub(crate) fn something_went_wrong<D: core::fmt::Display>(
_src: &str, _err: D,
) -> ! {
// Note(Lokathor): Keeping the panic here makes the panic _formatting_ go // Note(Lokathor): Keeping the panic here makes the panic _formatting_ go
// here too, which helps assembly readability and also helps keep down // here too, which helps assembly readability and also helps keep down
// the inline pressure. // the inline pressure.