mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-21 14:22:26 +00:00
Support unions in derive macros (#94)
* allow deriving traits on unions in some cases * basic union tests for nopadding and anybitpattern * implement derives for unions for more traits * remove Pod and AnyBitPattern derives for unions due to possible unsoundness
This commit is contained in:
parent
b472189ff8
commit
1652a2dcd2
@ -5,8 +5,8 @@ extern crate proc_macro;
|
||||
mod traits;
|
||||
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, DeriveInput};
|
||||
use quote::{quote, quote_spanned};
|
||||
use syn::{parse_macro_input, DeriveInput, spanned::Spanned};
|
||||
|
||||
use crate::traits::{
|
||||
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
|
||||
@ -223,8 +223,9 @@ pub fn derive_contiguous(
|
||||
|
||||
/// Basic wrapper for error handling
|
||||
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
||||
let span = input.span();
|
||||
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
|
||||
quote! {
|
||||
quote_spanned! { span =>
|
||||
compile_error!(#err);
|
||||
}
|
||||
})
|
||||
|
@ -3,7 +3,7 @@ use quote::{quote, quote_spanned, ToTokens};
|
||||
use syn::{
|
||||
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
|
||||
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
|
||||
Variant,
|
||||
Variant, DataUnion,
|
||||
};
|
||||
|
||||
pub trait Derivable {
|
||||
@ -38,17 +38,23 @@ impl Derivable for Pod {
|
||||
|
||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
if !input.generics.params.is_empty() {
|
||||
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
|
||||
return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs");
|
||||
}
|
||||
|
||||
let assert_no_padding = generate_assert_no_padding(input)?;
|
||||
let assert_fields_are_pod =
|
||||
generate_fields_are_trait(input, Self::ident())?;
|
||||
match &input.data {
|
||||
Data::Struct(_) => {
|
||||
let assert_no_padding = generate_assert_no_padding(input)?;
|
||||
let assert_fields_are_pod =
|
||||
generate_fields_are_trait(input, Self::ident())?;
|
||||
|
||||
Ok(quote!(
|
||||
#assert_no_padding
|
||||
#assert_fields_are_pod
|
||||
))
|
||||
Ok(quote!(
|
||||
#assert_no_padding
|
||||
#assert_fields_are_pod
|
||||
))
|
||||
},
|
||||
Data::Enum(_) => Err("Deriving Pod is not supported for enums"),
|
||||
Data::Union(_) => Err("Deriving Pod is not supported for unions")
|
||||
}
|
||||
}
|
||||
|
||||
fn check_attributes(
|
||||
@ -59,7 +65,7 @@ impl Derivable for Pod {
|
||||
Some("C") => Ok(()),
|
||||
Some("transparent") => Ok(()),
|
||||
_ => {
|
||||
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]")
|
||||
Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -77,7 +83,11 @@ impl Derivable for AnyBitPattern {
|
||||
}
|
||||
|
||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
generate_fields_are_trait(input, Self::ident())
|
||||
match &input.data {
|
||||
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
|
||||
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
|
||||
Data::Enum(_) => Err("Deriving AnyBitPattern is not supported for enums"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,7 +99,11 @@ impl Derivable for Zeroable {
|
||||
}
|
||||
|
||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
generate_fields_are_trait(input, Self::ident())
|
||||
match &input.data {
|
||||
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
|
||||
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
|
||||
Data::Enum(_) => Err("Deriving Zeroable is not supported for enums"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,14 +121,14 @@ impl Derivable for NoPadding {
|
||||
match ty {
|
||||
Data::Struct(_) => match repr.as_deref() {
|
||||
Some("C" | "transparent") => Ok(()),
|
||||
_ => Err("NoPadding requires the struct to be #[repr(C)] or #[repr(transparent)]"),
|
||||
_ => Err("NoPadding derive requires the type to be #[repr(C)] or #[repr(transparent)]"),
|
||||
},
|
||||
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("NoPadding requires the enum to be an explicit #[repr(Int)]")
|
||||
},
|
||||
Data::Union(_) => Err("NoPadding can only be derived on enums and structs")
|
||||
Data::Union(_) => Err("NoPadding cannot be derived for unions")
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,7 +155,7 @@ impl Derivable for NoPadding {
|
||||
Ok(quote!())
|
||||
}
|
||||
}
|
||||
Data::Union(_) => Err("Internal error in NoPadding derive"), // shouldn't be possible since we already error in attribute check for this case
|
||||
Data::Union(_) => Err("NoPadding cannot be derived for unions")
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,7 +229,7 @@ impl TransparentWrapper {
|
||||
) -> Option<TokenStream> {
|
||||
let transparent_param = get_simple_attr(attributes, "transparent");
|
||||
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
|
||||
let mut types = get_field_types(fields);
|
||||
let mut types = get_field_types(&fields);
|
||||
let first_type = types.next();
|
||||
if let Some(_) = types.next() {
|
||||
// can't guess param type if there is more than one field
|
||||
@ -235,13 +249,13 @@ impl Derivable for TransparentWrapper {
|
||||
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
let fields = get_struct_fields(input)?;
|
||||
|
||||
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>))
|
||||
Self::get_wrapper_type(&input.attrs, &fields).map(|ty| quote!(<#ty>))
|
||||
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
|
||||
}
|
||||
|
||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
let fields = get_struct_fields(input)?;
|
||||
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) {
|
||||
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
|
||||
Some(wrapped_type) => wrapped_type.to_string(),
|
||||
None => unreachable!(), /* other code will already reject this derive */
|
||||
};
|
||||
@ -334,6 +348,14 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_fields(input: &DeriveInput) -> Result<Fields, &'static str> {
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
|
||||
Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
|
||||
Data::Enum(_) => Err("deriving this trait is not supported for enums")
|
||||
}
|
||||
}
|
||||
|
||||
fn get_enum_variants<'a>(
|
||||
input: &'a DeriveInput,
|
||||
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
|
||||
@ -459,7 +481,7 @@ fn generate_assert_no_padding(
|
||||
) -> Result<TokenStream, &'static str> {
|
||||
let struct_type = &input.ident;
|
||||
let span = input.ident.span();
|
||||
let fields = get_struct_fields(input)?;
|
||||
let fields = get_fields(input)?;
|
||||
|
||||
let mut field_types = get_field_types(&fields);
|
||||
let size_sum = if let Some(first) = field_types.next() {
|
||||
@ -484,7 +506,7 @@ fn generate_fields_are_trait(
|
||||
) -> Result<TokenStream, &'static str> {
|
||||
let (impl_generics, _ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
let fields = get_struct_fields(input)?;
|
||||
let fields = get_fields(input)?;
|
||||
let span = input.span();
|
||||
let field_types = get_field_types(&fields);
|
||||
Ok(quote_spanned! {span => #(const _: fn() = || {
|
||||
|
@ -1,7 +1,7 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytemuck::{
|
||||
AnyBitPattern, Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
|
||||
Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, AnyBitPattern,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@ -58,6 +58,13 @@ struct NoPaddingTest {
|
||||
b: u16,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, AnyBitPattern)]
|
||||
#[repr(C)]
|
||||
union UnionTestAnyBitPattern {
|
||||
a: u8,
|
||||
b: u16,
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)]
|
||||
enum CheckedBitPatternEnumWithValues {
|
||||
|
Loading…
Reference in New Issue
Block a user