diff --git a/derive/src/lib.rs b/derive/src/lib.rs index b26d90a..bbbdf9d 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -5,8 +5,8 @@ extern crate proc_macro; mod traits; use proc_macro2::TokenStream; -use quote::quote; -use syn::{parse_macro_input, DeriveInput}; +use quote::{quote, quote_spanned}; +use syn::{parse_macro_input, DeriveInput, spanned::Spanned}; use crate::traits::{ AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, @@ -223,8 +223,9 @@ pub fn derive_contiguous( /// Basic wrapper for error handling fn derive_marker_trait(input: DeriveInput) -> TokenStream { + let span = input.span(); derive_marker_trait_inner::(input).unwrap_or_else(|err| { - quote! { + quote_spanned! { span => compile_error!(#err); } }) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 31beeb4..38ba919 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -3,7 +3,7 @@ use quote::{quote, quote_spanned, ToTokens}; use syn::{ spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp, - Variant, + Variant, DataUnion, }; pub trait Derivable { @@ -38,17 +38,23 @@ impl Derivable for Pod { fn asserts(input: &DeriveInput) -> Result { if !input.generics.params.is_empty() { - return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs"); + return Err("Pod requires cannot be derived for types containing generic parameters because the padding requirements can't be verified for generic structs"); } - let assert_no_padding = generate_assert_no_padding(input)?; - let assert_fields_are_pod = - generate_fields_are_trait(input, Self::ident())?; + match &input.data { + Data::Struct(_) => { + let assert_no_padding = generate_assert_no_padding(input)?; + let assert_fields_are_pod = + generate_fields_are_trait(input, Self::ident())?; - Ok(quote!( - #assert_no_padding - #assert_fields_are_pod - )) + Ok(quote!( + #assert_no_padding + #assert_fields_are_pod + )) + }, + Data::Enum(_) => Err("Deriving Pod is not supported for enums"), + Data::Union(_) => Err("Deriving Pod is not supported for unions") + } } fn check_attributes( @@ -59,7 +65,7 @@ impl Derivable for Pod { Some("C") => Ok(()), Some("transparent") => Ok(()), _ => { - Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]") + Err("Pod requires the type to be #[repr(C)] or #[repr(transparent)]") } } } @@ -77,7 +83,11 @@ impl Derivable for AnyBitPattern { } fn asserts(input: &DeriveInput) -> Result { - generate_fields_are_trait(input, Self::ident()) + match &input.data { + Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern` + Data::Struct(_) => generate_fields_are_trait(input, Self::ident()), + Data::Enum(_) => Err("Deriving AnyBitPattern is not supported for enums"), + } } } @@ -89,7 +99,11 @@ impl Derivable for Zeroable { } fn asserts(input: &DeriveInput) -> Result { - generate_fields_are_trait(input, Self::ident()) + match &input.data { + Data::Union(_) => Ok(quote!()), // unions are always `Zeroable` + Data::Struct(_) => generate_fields_are_trait(input, Self::ident()), + Data::Enum(_) => Err("Deriving Zeroable is not supported for enums"), + } } } @@ -107,14 +121,14 @@ impl Derivable for NoPadding { match ty { Data::Struct(_) => match repr.as_deref() { Some("C" | "transparent") => Ok(()), - _ => Err("NoPadding requires the struct to be #[repr(C)] or #[repr(transparent)]"), + _ => Err("NoPadding derive requires the type to be #[repr(C)] or #[repr(transparent)]"), }, Data::Enum(_) => if repr.map(|repr| repr.starts_with('u') || repr.starts_with('i')) == Some(true) { Ok(()) } else { Err("NoPadding requires the enum to be an explicit #[repr(Int)]") }, - Data::Union(_) => Err("NoPadding can only be derived on enums and structs") + Data::Union(_) => Err("NoPadding cannot be derived for unions") } } @@ -141,7 +155,7 @@ impl Derivable for NoPadding { Ok(quote!()) } } - Data::Union(_) => Err("Internal error in NoPadding derive"), // shouldn't be possible since we already error in attribute check for this case + Data::Union(_) => Err("NoPadding cannot be derived for unions") } } @@ -215,7 +229,7 @@ impl TransparentWrapper { ) -> Option { let transparent_param = get_simple_attr(attributes, "transparent"); transparent_param.map(|ident| ident.to_token_stream()).or_else(|| { - let mut types = get_field_types(fields); + let mut types = get_field_types(&fields); let first_type = types.next(); if let Some(_) = types.next() { // can't guess param type if there is more than one field @@ -235,13 +249,13 @@ impl Derivable for TransparentWrapper { fn generic_params(input: &DeriveInput) -> Result { let fields = get_struct_fields(input)?; - Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>)) + Self::get_wrapper_type(&input.attrs, &fields).map(|ty| quote!(<#ty>)) .ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]") } fn asserts(input: &DeriveInput) -> Result { let fields = get_struct_fields(input)?; - let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) { + let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) { Some(wrapped_type) => wrapped_type.to_string(), None => unreachable!(), /* other code will already reject this derive */ }; @@ -334,6 +348,14 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> { } } +fn get_fields(input: &DeriveInput) -> Result { + match &input.data { + Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()), + Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())), + Data::Enum(_) => Err("deriving this trait is not supported for enums") + } +} + fn get_enum_variants<'a>( input: &'a DeriveInput, ) -> Result + 'a, &'static str> { @@ -459,7 +481,7 @@ fn generate_assert_no_padding( ) -> Result { let struct_type = &input.ident; let span = input.ident.span(); - let fields = get_struct_fields(input)?; + let fields = get_fields(input)?; let mut field_types = get_field_types(&fields); let size_sum = if let Some(first) = field_types.next() { @@ -484,7 +506,7 @@ fn generate_fields_are_trait( ) -> Result { let (impl_generics, _ty_generics, where_clause) = input.generics.split_for_impl(); - let fields = get_struct_fields(input)?; + let fields = get_fields(input)?; let span = input.span(); let field_types = get_field_types(&fields); Ok(quote_spanned! {span => #(const _: fn() = || { diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index 17b4d51..252ce02 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use bytemuck::{ - AnyBitPattern, Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, + Contiguous, CheckedBitPattern, NoPadding, Pod, TransparentWrapper, Zeroable, AnyBitPattern, }; use std::marker::PhantomData; @@ -58,6 +58,13 @@ struct NoPaddingTest { b: u16, } +#[derive(Copy, Clone, AnyBitPattern)] +#[repr(C)] +union UnionTestAnyBitPattern { + a: u8, + b: u16, +} + #[repr(u8)] #[derive(Debug, Clone, Copy, NoPadding, CheckedBitPattern, PartialEq, Eq)] enum CheckedBitPatternEnumWithValues {