Support unions in derive macros (#94)

* allow deriving traits on unions in some cases

* basic union tests for nopadding and anybitpattern

* implement derives for unions for more traits

* remove Pod and AnyBitPattern derives for unions due to possible unsoundness
This commit is contained in:
Gray Olson 2022-03-29 19:01:48 -07:00 committed by GitHub
parent b472189ff8
commit 1652a2dcd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 24 deletions

View File

@ -5,8 +5,8 @@ extern crate proc_macro;
mod traits; mod traits;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::quote; use quote::{quote, quote_spanned};
use syn::{parse_macro_input, DeriveInput}; use syn::{parse_macro_input, DeriveInput, spanned::Spanned};
use crate::traits::{ use crate::traits::{
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable,
@ -223,8 +223,9 @@ pub fn derive_contiguous(
/// Basic wrapper for error handling /// Basic wrapper for error handling
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream { fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
let span = input.span();
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| { derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
quote! { quote_spanned! { span =>
compile_error!(#err); compile_error!(#err);
} }
}) })

View File

@ -3,7 +3,7 @@ use quote::{quote, quote_spanned, ToTokens};
use syn::{ use syn::{
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct, spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
Variant, Variant, DataUnion,
}; };
pub trait Derivable { pub trait Derivable {
@ -38,9 +38,11 @@ impl Derivable for Pod {
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> { fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
if !input.generics.params.is_empty() { if !input.generics.params.is_empty() {
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs"); return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs");
} }
match &input.data {
Data::Struct(_) => {
let assert_no_padding = generate_assert_no_padding(input)?; let assert_no_padding = generate_assert_no_padding(input)?;
let assert_fields_are_pod = let assert_fields_are_pod =
generate_fields_are_trait(input, Self::ident())?; generate_fields_are_trait(input, Self::ident())?;
@ -49,6 +51,10 @@ impl Derivable for Pod {
#assert_no_padding #assert_no_padding
#assert_fields_are_pod #assert_fields_are_pod
)) ))
},
Data::Enum(_) => Err("Deriving Pod is not supported for enums"),
Data::Union(_) => Err("Deriving Pod is not supported for unions")
}
} }
fn check_attributes( fn check_attributes(
@ -59,7 +65,7 @@ impl Derivable for Pod {
Some("C") => Ok(()), Some("C") => Ok(()),
Some("transparent") => Ok(()), Some("transparent") => Ok(()),
_ => { _ => {
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]") Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
} }
} }
} }
@ -77,7 +83,11 @@ impl Derivable for AnyBitPattern {
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> { fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
generate_fields_are_trait(input, Self::ident()) match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
Data::Enum(_) => Err("Deriving AnyBitPattern is not supported for enums"),
}
} }
} }
@ -89,7 +99,11 @@ impl Derivable for Zeroable {
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> { fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
generate_fields_are_trait(input, Self::ident()) match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
Data::Struct(_) => generate_fields_are_trait(input, Self::ident()),
Data::Enum(_) => Err("Deriving Zeroable is not supported for enums"),
}
} }
} }
@ -107,14 +121,14 @@ impl Derivable for NoPadding {
match ty { match ty {
Data::Struct(_) => match repr.as_deref() { Data::Struct(_) => match repr.as_deref() {
Some("C" | "transparent") => Ok(()), Some("C" | "transparent") => Ok(()),
_ => Err("NoPadding requires the struct to be #[repr(C)] or #[repr(transparent)]"), _ => Err("NoPadding derive requires the type to be #[repr(C)] or #[repr(transparent)]"),
}, },
Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) { Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) {
Ok(()) Ok(())
} else { } else {
Err("NoPadding requires the enum to be an explicit #[repr(Int)]") Err("NoPadding requires the enum to be an explicit #[repr(Int)]")
}, },
Data::Union(_) => Err("NoPadding can only be derived on enums and structs") Data::Union(_) => Err("NoPadding cannot be derived for unions")
} }
} }
@ -141,7 +155,7 @@ impl Derivable for NoPadding {
Ok(quote!()) Ok(quote!())
} }
} }
Data::Union(_) => Err("Internal error in NoPadding derive"), // shouldn't be possible since we already error in attribute check for this case Data::Union(_) => Err("NoPadding cannot be derived for unions")
} }
} }
@ -215,7 +229,7 @@ impl TransparentWrapper {
) -> Option<TokenStream> { ) -> Option<TokenStream> {
let transparent_param = get_simple_attr(attributes, "transparent"); let transparent_param = get_simple_attr(attributes, "transparent");
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| { transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
let mut types = get_field_types(fields); let mut types = get_field_types(&fields);
let first_type = types.next(); let first_type = types.next();
if let Some(_) = types.next() { if let Some(_) = types.next() {
// can't guess param type if there is more than one field // can't guess param type if there is more than one field
@ -235,13 +249,13 @@ impl Derivable for TransparentWrapper {
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> { fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let fields = get_struct_fields(input)?; let fields = get_struct_fields(input)?;
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>)) Self::get_wrapper_type(&input.attrs, &fields).map(|ty| quote!(<#ty>))
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]") .ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> { fn asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let fields = get_struct_fields(input)?; let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) { let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(wrapped_type) => wrapped_type.to_string(), Some(wrapped_type) => wrapped_type.to_string(),
None => unreachable!(), /* other code will already reject this derive */ None => unreachable!(), /* other code will already reject this derive */
}; };
@ -334,6 +348,14 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
} }
} }
fn get_fields(input: &DeriveInput) -> Result<Fields, &'static str> {
match &input.data {
Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
Data::Enum(_) => Err("deriving this trait is not supported for enums")
}
}
fn get_enum_variants<'a>( fn get_enum_variants<'a>(
input: &'a DeriveInput, input: &'a DeriveInput,
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> { ) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
@ -459,7 +481,7 @@ fn generate_assert_no_padding(
) -> Result<TokenStream, &'static str> { ) -> Result<TokenStream, &'static str> {
let struct_type = &input.ident; let struct_type = &input.ident;
let span = input.ident.span(); let span = input.ident.span();
let fields = get_struct_fields(input)?; let fields = get_fields(input)?;
let mut field_types = get_field_types(&fields); let mut field_types = get_field_types(&fields);
let size_sum = if let Some(first) = field_types.next() { let size_sum = if let Some(first) = field_types.next() {
@ -484,7 +506,7 @@ fn generate_fields_are_trait(
) -> Result<TokenStream, &'static str> { ) -> Result<TokenStream, &'static str> {
let (impl_generics, _ty_generics, where_clause) = let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl(); input.generics.split_for_impl();
let fields = get_struct_fields(input)?; let fields = get_fields(input)?;
let span = input.span(); let span = input.span();
let field_types = get_field_types(&fields); let field_types = get_field_types(&fields);
Ok(quote_spanned! {span => #(const _: fn() = || { Ok(quote_spanned! {span => #(const _: fn() = || {

View File

@ -1,7 +1,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use bytemuck::{ use bytemuck::{
AnyBitPattern, Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, AnyBitPattern,
}; };
use std::marker::PhantomData; use std::marker::PhantomData;
@ -58,6 +58,13 @@ struct NoPaddingTest {
b: u16, b: u16,
} }
#[derive(Copy, Clone, AnyBitPattern)]
#[repr(C)]
union UnionTestAnyBitPattern {
a: u8,
b: u16,
}
#[repr(u8)] #[repr(u8)]
#[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)] #[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)]
enum CheckedBitPatternEnumWithValues { enum CheckedBitPatternEnumWithValues {