diff --git a/Cargo.toml b/Cargo.toml index e768b61..3fc2ef1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,12 @@ exclude = ["/pedantic.bat"] extern_crate_alloc = [] extern_crate_std = ["extern_crate_alloc"] zeroable_maybe_uninit = [] +derive = ["bytemuck_derive"] + +[dependencies] +bytemuck_derive = { version = "1.0", path = "derive", optional = true } [package.metadata.docs.rs] all-features = true + +[workspace] \ No newline at end of file diff --git a/derive/Cargo.toml b/derive/Cargo.toml new file mode 100644 index 0000000..251615f --- /dev/null +++ b/derive/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "bytemuck_derive" +description = "A crate for mucking around with piles of bytes." +version = "1.0.0" +authors = ["Lokathor "] +repository = "https://github.com/Lokathor/bytemuck" +readme = "README.md" +keywords = ["transmute", "bytes", "casting"] +categories = ["encoding", "no-std"] +edition = "2018" +license = "Zlib OR Apache-2.0 OR MIT" + +[lib] +name = "bytemuck_derive" +proc-macro = true + +[dependencies] +syn = "1.0" +quote = "1.0" +proc-macro2 = "1.0" + +[dev-dependencies] +bytemuck = { version = "1.3.2-alpha.0", path = ".." } diff --git a/derive/src/lib.rs b/derive/src/lib.rs new file mode 100644 index 0000000..36f89a5 --- /dev/null +++ b/derive/src/lib.rs @@ -0,0 +1,140 @@ +//! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits + +mod traits; + +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable}; + +/// Derive the `Pod` trait for a struct +/// +/// The macro ensures that the struct follows all the the safety requirements +/// for the `Pod` trait. +/// +/// The following constraints need to be satisfied for the macro to succeed +/// +/// - All fields in the struct must implement `Pod` +/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]` +/// - The struct must not contain any padding bytes +/// - The struct contains no generic parameters +/// +/// ## Example +/// +/// ```rust +/// # use bytemuck_derive::{Pod, Zeroable}; +/// +/// #[derive(Copy, Clone, Pod, Zeroable)] +/// #[repr(C)] +/// struct Test { +/// a: u16, +/// b: u16, +/// } +/// ``` +#[proc_macro_derive(Pod)] +pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let expanded = + derive_marker_trait::(parse_macro_input!(input as DeriveInput)); + + proc_macro::TokenStream::from(expanded) +} + +/// Derive the `Zeroable` trait for a struct +/// +/// The macro ensures that the struct follows all the the safety requirements +/// for the `Zeroable` trait. +/// +/// The following constraints need to be satisfied for the macro to succeed +/// +/// - All fields ind the struct must to implement `Zeroable` +/// +/// ## Example +/// +/// ```rust +/// # use bytemuck_derive::{Zeroable}; +/// +/// #[derive(Copy, Clone, Zeroable)] +/// #[repr(C)] +/// struct Test { +/// a: u16, +/// b: u16, +/// } +/// ``` +#[proc_macro_derive(Zeroable)] +pub fn derive_zeroable( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let expanded = + derive_marker_trait::(parse_macro_input!(input as DeriveInput)); + + proc_macro::TokenStream::from(expanded) +} + +/// Derive the `TransparentWrapper` trait for a struct +/// +/// The macro ensures that the struct follows all the the safety requirements +/// for the `TransparentWrapper` trait. +/// +/// The following constraints need to be satisfied for the macro to succeed +/// +/// - The struct must be `#[repr(transparent)] +/// - The struct must contain the `Wrapped` type +/// +/// If the struct only contains a single field, the `Wrapped` type will +/// automatically be determined if there is more then one field in the struct, +/// you need to specify the `Wrapped` type using `#[transparent(T)]` +/// +/// ## Example +/// +/// ```rust +/// # use bytemuck_derive::TransparentWrapper; +/// # use std::marker::PhantomData; +/// +/// #[derive(Copy, Clone, TransparentWrapper)] +/// #[repr(transparent)] +/// #[transparent(u16)] +/// struct Test { +/// inner: u16, +/// extra: PhantomData, +/// } +/// ``` +#[proc_macro_derive(TransparentWrapper, attributes(transparent))] +pub fn derive_transparent( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let expanded = derive_marker_trait::(parse_macro_input!( + input as DeriveInput + )); + + proc_macro::TokenStream::from(expanded) +} + +/// Basic wrapper for error handling +fn derive_marker_trait(input: DeriveInput) -> TokenStream { + derive_marker_trait_inner::(input).unwrap_or_else(|err| { + quote! { + compile_error!(#err); + } + }) +} + +fn derive_marker_trait_inner( + input: DeriveInput, +) -> Result { + let name = &input.ident; + + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); + + let trait_ = Trait::ident(); + Trait::check_attributes(&input.attrs)?; + let asserts = Trait::struct_asserts(&input)?; + let trait_params = Trait::generic_params(&input)?; + + Ok(quote! { + #asserts + + unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {} + }) +} diff --git a/derive/src/traits.rs b/derive/src/traits.rs new file mode 100644 index 0000000..5eb86ee --- /dev/null +++ b/derive/src/traits.rs @@ -0,0 +1,207 @@ +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::{ + spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput, + Fields, Type, +}; + +pub trait Derivable { + fn ident() -> TokenStream; + fn generic_params(_input: &DeriveInput) -> Result { + Ok(quote!()) + } + fn struct_asserts(input: &DeriveInput) -> Result; + fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> { + Ok(()) + } +} + +pub struct Pod; + +impl Derivable for Pod { + fn ident() -> TokenStream { + quote!(::bytemuck::Pod) + } + + fn struct_asserts(input: &DeriveInput) -> Result { + if !input.generics.params.is_empty() { + return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs"); + } + + let assert_no_padding = generate_assert_no_padding(input)?; + let assert_fields_are_pod = + generate_fields_are_trait(input, Self::ident())?; + + Ok(quote!( + #assert_no_padding + #assert_fields_are_pod + )) + } + + fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> { + let repr = get_repr(attributes); + match repr.as_ref().map(|repr| repr.as_str()) { + Some("C") => Ok(()), + Some("transparent") => Ok(()), + _ => { + Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]") + } + } + } +} + +pub struct Zeroable; + +impl Derivable for Zeroable { + fn ident() -> TokenStream { + quote!(::bytemuck::Zeroable) + } + + fn struct_asserts(input: &DeriveInput) -> Result { + generate_fields_are_trait(input, Self::ident()) + } +} + +pub struct TransparentWrapper; + +impl TransparentWrapper { + fn get_wrapper_type( + attributes: &[Attribute], fields: &Fields, + ) -> Option { + let transparent_param = get_simple_attr(attributes, "transparent"); + transparent_param.map(|ident| ident.to_token_stream()).or_else(|| { + let mut types = get_field_types(fields); + let first_type = types.next(); + if let Some(_) = types.next() { + // can't guess param type if there is more than one field + return None; + } else { + first_type.map(|ty| ty.to_token_stream()) + } + }) + } +} + +impl Derivable for TransparentWrapper { + fn ident() -> TokenStream { + quote!(::bytemuck::TransparentWrapper) + } + + fn generic_params(input: &DeriveInput) -> Result { + let fields = get_struct_fields(input)?; + + Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>)) + .ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]") + } + + fn struct_asserts(input: &DeriveInput) -> Result { + let fields = get_struct_fields(input)?; + let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) { + Some(wrapped_type) => wrapped_type.to_string(), + None => unreachable!(), /* other code will already reject this derive */ + }; + let mut wrapped_fields = fields + .iter() + .filter(|field| field.ty.to_token_stream().to_string() == wrapped_type); + if let None = wrapped_fields.next() { + return Err("TransparentWrapper must have one field of the wrapped type"); + }; + if let Some(_) = wrapped_fields.next() { + Err("TransparentWrapper can only have one field of the wrapped type") + } else { + Ok(quote!()) + } + } + + fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> { + let repr = get_repr(attributes); + + match repr.as_ref().map(|repr| repr.as_str()) { + Some("transparent") => Ok(()), + _ => { + Err("TransparentWrapper requires the struct to be #[repr(transparent)]") + } + } + } +} + +fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> { + if let Data::Struct(DataStruct { fields, .. }) = &input.data { + Ok(fields) + } else { + Err("deriving this trait is only supported for structs") + } +} + +fn get_field_types<'a>( + fields: &'a Fields, +) -> impl Iterator + 'a { + fields.iter().map(|field| &field.ty) +} + +/// Check that a struct has no padding by asserting that the size of the struct +/// is equal to the sum of the size of it's fields +fn generate_assert_no_padding( + input: &DeriveInput, +) -> Result { + let struct_type = &input.ident; + let span = input.span(); + let fields = get_struct_fields(input)?; + + let field_types = get_field_types(&fields); + let struct_size = + quote_spanned!(span => core::mem::size_of::<#struct_type>()); + let size_sum = + quote_spanned!(span => 0 #( + core::mem::size_of::<#field_types>() )*); + + Ok(quote_spanned! {span => const _: fn() = || { + let _ = core::mem::transmute::<[u8; #struct_size], [u8; #size_sum]>; + };}) +} + +/// Check that all fields implement a given trait +fn generate_fields_are_trait( + input: &DeriveInput, trait_: TokenStream, +) -> Result { + let (impl_generics, _ty_generics, where_clause) = + input.generics.split_for_impl(); + let fields = get_struct_fields(input)?; + let span = input.span(); + let field_types = get_field_types(&fields); + Ok(quote_spanned! {span => #(const _: fn() = || { + fn check #impl_generics () #where_clause { + fn assert_impl() {} + assert_impl::<#field_types>(); + } + };)* + }) +} + +fn get_ident_from_stream(tokens: TokenStream) -> Option { + match tokens.into_iter().next() { + Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()), + Some(TokenTree::Ident(ident)) => Some(ident), + _ => None, + } +} + +/// get a simple #[foo(bar)] attribute, returning "bar" +fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option { + for attr in attributes { + if let (AttrStyle::Outer, Some(outer_ident), Some(inner_ident)) = ( + &attr.style, + attr.path.get_ident(), + get_ident_from_stream(attr.tokens.clone()), + ) { + if outer_ident.to_string() == attr_name { + return Some(inner_ident); + } + } + } + + None +} + +fn get_repr(attributes: &[Attribute]) -> Option { + get_simple_attr(attributes, "repr").map(|ident| ident.to_string()) +} diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs new file mode 100644 index 0000000..8ab69b2 --- /dev/null +++ b/derive/tests/basic.rs @@ -0,0 +1,30 @@ +#![allow(dead_code)] + +use bytemuck_derive::{Pod, TransparentWrapper, Zeroable}; +use std::marker::PhantomData; + +#[derive(Copy, Clone, Pod, Zeroable)] +#[repr(C)] +struct Test { + a: u16, + b: u16, +} + +#[derive(Zeroable)] +struct ZeroGeneric { + a: T, +} + +#[derive(TransparentWrapper)] +#[repr(transparent)] +struct TransparentSingle { + a: u16, +} + +#[derive(TransparentWrapper)] +#[repr(transparent)] +#[transparent(u16)] +struct TransparentWithZeroSized { + a: u16, + b: PhantomData, +} diff --git a/src/lib.rs b/src/lib.rs index 90455e0..b9a3ec2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,6 +81,9 @@ pub use offset_of::*; mod transparent; pub use transparent::*; +#[cfg(feature = "derive")] +pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper}; + /* Note(Lokathor): We've switched all of the `unwrap` to `match` because there is diff --git a/tests/derive.rs b/tests/derive.rs new file mode 100644 index 0000000..1e1deeb --- /dev/null +++ b/tests/derive.rs @@ -0,0 +1,25 @@ +#![cfg(feature = "derive")] +#![allow(dead_code)] + +use bytemuck::{Zeroable, Pod, TransparentWrapper}; + +#[derive(Copy, Clone, Pod, Zeroable)] +#[repr(C)] +struct Test { + a: u16, + b: u16, +} + +#[derive(TransparentWrapper)] +#[repr(transparent)] +struct TransparentSingle { + a: u16, +} + +#[derive(TransparentWrapper)] +#[repr(transparent)] +#[transparent(u16)] +struct TransparentWithZeroSized { + a: u16, + b: () +} \ No newline at end of file