mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-22 06:42:25 +00:00
Support unions in derive macros (#94)
* allow deriving traits on unions in some cases * basic union tests for nopadding and anybitpattern * implement derives for unions for more traits * remove Pod and AnyBitPattern derives for unions due to possible unsoundness
This commit is contained in:
parent
b472189ff8
commit
1652a2dcd2
@ -5,8 +5,8 @@ extern crate proc_macro;
|
|||||||
mod traits;
|
mod traits;
|
||||||
|
|
||||||
use proc_macro2::TokenStream;
|
use proc_macro2::TokenStream;
|
||||||
use quote::quote;
|
use quote::{quote, quote_spanned};
|
||||||
use syn::{parse_macro_input, DeriveInput};
|
use syn::{parse_macro_input, DeriveInput, spanned::Spanned};
|
||||||
|
|
||||||
use crate::traits::{
|
use crate::traits::{
|
||||||
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
|
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
|
||||||
@ -223,8 +223,9 @@ pub fn derive_contiguous(
|
|||||||
|
|
||||||
/// Basic wrapper for error handling
|
/// Basic wrapper for error handling
|
||||||
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
||||||
|
let span = input.span();
|
||||||
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
|
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
|
||||||
quote! {
|
quote_spanned! { span =>
|
||||||
compile_error!(#err);
|
compile_error!(#err);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -3,7 +3,7 @@ use quote::{quote, quote_spanned, ToTokens};
|
|||||||
use syn::{
|
use syn::{
|
||||||
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
|
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
|
||||||
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
|
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
|
||||||
Variant,
|
Variant, DataUnion,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub trait Derivable {
|
pub trait Derivable {
|
||||||
@ -38,9 +38,11 @@ impl Derivable for Pod {
|
|||||||
|
|
||||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
if !input.generics.params.is_empty() {
|
if !input.generics.params.is_empty() {
|
||||||
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
|
return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match &input.data {
|
||||||
|
Data::Struct(_) => {
|
||||||
let assert_no_padding = generate_assert_no_padding(input)?;
|
let assert_no_padding = generate_assert_no_padding(input)?;
|
||||||
let assert_fields_are_pod =
|
let assert_fields_are_pod =
|
||||||
generate_fields_are_trait(input, Self::ident())?;
|
generate_fields_are_trait(input, Self::ident())?;
|
||||||
@ -49,6 +51,10 @@ impl Derivable for Pod {
|
|||||||
#assert_no_padding
|
#assert_no_padding
|
||||||
#assert_fields_are_pod
|
#assert_fields_are_pod
|
||||||
))
|
))
|
||||||
|
},
|
||||||
|
Data::Enum(_) => Err("Deriving Pod is not supported for enums"),
|
||||||
|
Data::Union(_) => Err("Deriving Pod is not supported for unions")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_attributes(
|
fn check_attributes(
|
||||||
@ -59,7 +65,7 @@ impl Derivable for Pod {
|
|||||||
Some("C") => Ok(()),
|
Some("C") => Ok(()),
|
||||||
Some("transparent") => Ok(()),
|
Some("transparent") => Ok(()),
|
||||||
_ => {
|
_ => {
|
||||||
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]")
|
Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -77,7 +83,11 @@ impl Derivable for AnyBitPattern {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
generate_fields_are_trait(input, Self::ident())
|
match &input.data {
|
||||||
|
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
|
||||||
|
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
|
||||||
|
Data::Enum(_) => Err("Deriving AnyBitPattern is not supported for enums"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +99,11 @@ impl Derivable for Zeroable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
generate_fields_are_trait(input, Self::ident())
|
match &input.data {
|
||||||
|
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
|
||||||
|
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
|
||||||
|
Data::Enum(_) => Err("Deriving Zeroable is not supported for enums"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,14 +121,14 @@ impl Derivable for NoPadding {
|
|||||||
match ty {
|
match ty {
|
||||||
Data::Struct(_) => match repr.as_deref() {
|
Data::Struct(_) => match repr.as_deref() {
|
||||||
Some("C" | "transparent") => Ok(()),
|
Some("C" | "transparent") => Ok(()),
|
||||||
_ => Err("NoPadding requires the struct to be #[repr(C)] or #[repr(transparent)]"),
|
_ => Err("NoPadding derive requires the type to be #[repr(C)] or #[repr(transparent)]"),
|
||||||
},
|
},
|
||||||
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) {
|
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
Err("NoPadding requires the enum to be an explicit #[repr(Int)]")
|
Err("NoPadding requires the enum to be an explicit #[repr(Int)]")
|
||||||
},
|
},
|
||||||
Data::Union(_) => Err("NoPadding can only be derived on enums and structs")
|
Data::Union(_) => Err("NoPadding cannot be derived for unions")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,7 +155,7 @@ impl Derivable for NoPadding {
|
|||||||
Ok(quote!())
|
Ok(quote!())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Data::Union(_) => Err("Internal error in NoPadding derive"), // shouldn't be possible since we already error in attribute check for this case
|
Data::Union(_) => Err("NoPadding cannot be derived for unions")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +229,7 @@ impl TransparentWrapper {
|
|||||||
) -> Option<TokenStream> {
|
) -> Option<TokenStream> {
|
||||||
let transparent_param = get_simple_attr(attributes, "transparent");
|
let transparent_param = get_simple_attr(attributes, "transparent");
|
||||||
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
|
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
|
||||||
let mut types = get_field_types(fields);
|
let mut types = get_field_types(&fields);
|
||||||
let first_type = types.next();
|
let first_type = types.next();
|
||||||
if let Some(_) = types.next() {
|
if let Some(_) = types.next() {
|
||||||
// can't guess param type if there is more than one field
|
// can't guess param type if there is more than one field
|
||||||
@ -235,13 +249,13 @@ impl Derivable for TransparentWrapper {
|
|||||||
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
let fields = get_struct_fields(input)?;
|
let fields = get_struct_fields(input)?;
|
||||||
|
|
||||||
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>))
|
Self::get_wrapper_type(&input.attrs, &fields).map(|ty| quote!(<#ty>))
|
||||||
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
|
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
let fields = get_struct_fields(input)?;
|
let fields = get_struct_fields(input)?;
|
||||||
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) {
|
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
|
||||||
Some(wrapped_type) => wrapped_type.to_string(),
|
Some(wrapped_type) => wrapped_type.to_string(),
|
||||||
None => unreachable!(), /* other code will already reject this derive */
|
None => unreachable!(), /* other code will already reject this derive */
|
||||||
};
|
};
|
||||||
@ -334,6 +348,14 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_fields(input: &DeriveInput) -> Result<Fields, &'static str> {
|
||||||
|
match &input.data {
|
||||||
|
Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
|
||||||
|
Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
|
||||||
|
Data::Enum(_) => Err("deriving this trait is not supported for enums")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_enum_variants<'a>(
|
fn get_enum_variants<'a>(
|
||||||
input: &'a DeriveInput,
|
input: &'a DeriveInput,
|
||||||
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
|
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
|
||||||
@ -459,7 +481,7 @@ fn generate_assert_no_padding(
|
|||||||
) -> Result<TokenStream, &'static str> {
|
) -> Result<TokenStream, &'static str> {
|
||||||
let struct_type = &input.ident;
|
let struct_type = &input.ident;
|
||||||
let span = input.ident.span();
|
let span = input.ident.span();
|
||||||
let fields = get_struct_fields(input)?;
|
let fields = get_fields(input)?;
|
||||||
|
|
||||||
let mut field_types = get_field_types(&fields);
|
let mut field_types = get_field_types(&fields);
|
||||||
let size_sum = if let Some(first) = field_types.next() {
|
let size_sum = if let Some(first) = field_types.next() {
|
||||||
@ -484,7 +506,7 @@ fn generate_fields_are_trait(
|
|||||||
) -> Result<TokenStream, &'static str> {
|
) -> Result<TokenStream, &'static str> {
|
||||||
let (impl_generics, _ty_generics, where_clause) =
|
let (impl_generics, _ty_generics, where_clause) =
|
||||||
input.generics.split_for_impl();
|
input.generics.split_for_impl();
|
||||||
let fields = get_struct_fields(input)?;
|
let fields = get_fields(input)?;
|
||||||
let span = input.span();
|
let span = input.span();
|
||||||
let field_types = get_field_types(&fields);
|
let field_types = get_field_types(&fields);
|
||||||
Ok(quote_spanned! {span => #(const _: fn() = || {
|
Ok(quote_spanned! {span => #(const _: fn() = || {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use bytemuck::{
|
use bytemuck::{
|
||||||
AnyBitPattern, Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
|
Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, AnyBitPattern,
|
||||||
};
|
};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
@ -58,6 +58,13 @@ struct NoPaddingTest {
|
|||||||
b: u16,
|
b: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, AnyBitPattern)]
|
||||||
|
#[repr(C)]
|
||||||
|
union UnionTestAnyBitPattern {
|
||||||
|
a: u8,
|
||||||
|
b: u16,
|
||||||
|
}
|
||||||
|
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
#[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)]
|
||||||
enum CheckedBitPatternEnumWithValues {
|
enum CheckedBitPatternEnumWithValues {
|
||||||
|
Loading…
Reference in New Issue
Block a user