Support unions in derive macros (#94)

* allow deriving traits on unions in some cases

* basic union tests for nopadding and anybitpattern

* implement derives for unions for more traits

* remove Pod and AnyBitPattern derives for unions due to possible unsoundness
This commit is contained in:
Gray Olson 2022-03-29 19:01:48 -07:00 committed by GitHub
parent b472189ff8
commit 1652a2dcd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 24 deletions

View File

@ -5,8 +5,8 @@ extern crate proc_macro;
mod traits;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
use quote::{quote, quote_spanned};
use syn::{parse_macro_input, DeriveInput, spanned::Spanned};
use crate::traits::{
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
@ -223,8 +223,9 @@ pub fn derive_contiguous(
/// Basic wrapper for error handling
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
let span = input.span();
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
quote! {
quote_spanned! { span =>
compile_error!(#err);
}
})

View File

@ -3,7 +3,7 @@ use quote::{quote, quote_spanned, ToTokens};
use syn::{
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
Variant,
Variant, DataUnion,
};
pub trait Derivable {
@ -38,9 +38,11 @@ impl Derivable for Pod {
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
if !input.generics.params.is_empty() {
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs");
}
match &input.data {
Data::Struct(_) => {
let assert_no_padding = generate_assert_no_padding(input)?;
let assert_fields_are_pod =
generate_fields_are_trait(input, Self::ident())?;
@ -49,6 +51,10 @@ impl Derivable for Pod {
#assert_no_padding
#assert_fields_are_pod
))
},
Data::Enum(_) => Err("Deriving Pod is not supported for enums"),
Data::Union(_) => Err("Deriving Pod is not supported for unions")
}
}
fn check_attributes(
@ -59,7 +65,7 @@ impl Derivable for Pod {
Some("C") => Ok(()),
Some("transparent") => Ok(()),
_ => {
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]")
Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
}
}
}
@ -77,7 +83,11 @@ impl Derivable for AnyBitPattern {
}
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
generate_fields_are_trait(input, Self::ident())
match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
Data::Enum(_) => Err("Deriving AnyBitPattern is not supported for enums"),
}
}
}
@ -89,7 +99,11 @@ impl Derivable for Zeroable {
}
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
generate_fields_are_trait(input, Self::ident())
match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
Data::Enum(_) => Err("Deriving Zeroable is not supported for enums"),
}
}
}
@ -107,14 +121,14 @@ impl Derivable for NoPadding {
match ty {
Data::Struct(_) => match repr.as_deref() {
Some("C" | "transparent") => Ok(()),
_ => Err("NoPadding requires the struct to be #[repr(C)] or #[repr(transparent)]"),
_ => Err("NoPadding derive requires the type to be #[repr(C)] or #[repr(transparent)]"),
},
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) {
Ok(())
} else {
Err("NoPadding requires the enum to be an explicit #[repr(Int)]")
},
Data::Union(_) => Err("NoPadding can only be derived on enums and structs")
Data::Union(_) => Err("NoPadding cannot be derived for unions")
}
}
@ -141,7 +155,7 @@ impl Derivable for NoPadding {
Ok(quote!())
}
}
Data::Union(_) => Err("Internal error in NoPadding derive"), // shouldn't be possible since we already error in attribute check for this case
Data::Union(_) => Err("NoPadding cannot be derived for unions")
}
}
@ -215,7 +229,7 @@ impl TransparentWrapper {
) -> Option<TokenStream> {
let transparent_param = get_simple_attr(attributes, "transparent");
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
let mut types = get_field_types(fields);
let mut types = get_field_types(&fields);
let first_type = types.next();
if let Some(_) = types.next() {
// can't guess param type if there is more than one field
@ -235,13 +249,13 @@ impl Derivable for TransparentWrapper {
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let fields = get_struct_fields(input)?;
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>))
Self::get_wrapper_type(&input.attrs, &fields).map(|ty| quote!(<#ty>))
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
}
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) {
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(wrapped_type) => wrapped_type.to_string(),
None => unreachable!(), /* other code will already reject this derive */
};
@ -334,6 +348,14 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
}
}
fn get_fields(input: &DeriveInput) -> Result<Fields, &'static str> {
match &input.data {
Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
Data::Enum(_) => Err("deriving this trait is not supported for enums")
}
}
fn get_enum_variants<'a>(
input: &'a DeriveInput,
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
@ -459,7 +481,7 @@ fn generate_assert_no_padding(
) -> Result<TokenStream, &'static str> {
let struct_type = &input.ident;
let span = input.ident.span();
let fields = get_struct_fields(input)?;
let fields = get_fields(input)?;
let mut field_types = get_field_types(&fields);
let size_sum = if let Some(first) = field_types.next() {
@ -484,7 +506,7 @@ fn generate_fields_are_trait(
) -> Result<TokenStream, &'static str> {
let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl();
let fields = get_struct_fields(input)?;
let fields = get_fields(input)?;
let span = input.span();
let field_types = get_field_types(&fields);
Ok(quote_spanned! {span => #(const _: fn() = || {

View File

@ -1,7 +1,7 @@
#![allow(dead_code)]
use bytemuck::{
AnyBitPattern, Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, AnyBitPattern,
};
use std::marker::PhantomData;
@ -58,6 +58,13 @@ struct NoPaddingTest {
b: u16,
}
#[derive(Copy, Clone, AnyBitPattern)]
#[repr(C)]
union UnionTestAnyBitPattern {
a: u8,
b: u16,
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)]
enum CheckedBitPatternEnumWithValues {