add basic derive macro for Pod, Zeroable and TransparentWrapper for structs (#30)

* add basic derive macro for Pod and Zeroable for structs

* add derive macro for TransparentWrapper

* use core::mem::size_of instead of std::mem::size_of in generated code

* cleanup error handling a bit

* remove unneeded iter logic

* remove unneeded clone and order impl

* fix generics

* fix doc typo

Co-authored-by: Lucien Greathouse <me@lpghatguy.com>

* remove unneeded lifetime anotation

* use unreachable for already rejected patch

Co-authored-by: Lucien Greathouse <me@lpghatguy.com>
This commit is contained in:
Robin Appelman 2020-08-21 01:04:36 +00:00 committed by GitHub
parent 94d71d9925
commit cf944452b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 434 additions and 0 deletions

View File

@ -16,6 +16,12 @@ exclude = ["/pedantic.bat"]
extern_crate_alloc = []
extern_crate_std = ["extern_crate_alloc"]
zeroable_maybe_uninit = []
derive = ["bytemuck_derive"]
[dependencies]
bytemuck_derive = { version = "1.0", path = "derive", optional = true }
[package.metadata.docs.rs]
all-features = true
[workspace]

23
derive/Cargo.toml Normal file
View File

@ -0,0 +1,23 @@
[package]
name = "bytemuck_derive"
description = "A crate for mucking around with piles of bytes."
version = "1.0.0"
authors = ["Lokathor <zefria@gmail.com>"]
repository = "https://github.com/Lokathor/bytemuck"
readme = "README.md"
keywords = ["transmute", "bytes", "casting"]
categories = ["encoding", "no-std"]
edition = "2018"
license = "Zlib OR Apache-2.0 OR MIT"
[lib]
name = "bytemuck_derive"
proc-macro = true
[dependencies]
syn = "1.0"
quote = "1.0"
proc-macro2 = "1.0"
[dev-dependencies]
bytemuck = { version = "1.3.2-alpha.0", path = ".." }

140
derive/src/lib.rs Normal file
View File

@ -0,0 +1,140 @@
//! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits
mod traits;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable};
/// Derive the `Pod` trait for a struct
///
/// The macro ensures that the struct follows all the the safety requirements
/// for the `Pod` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
///
/// - All fields in the struct must implement `Pod`
/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
/// - The struct must not contain any padding bytes
/// - The struct contains no generic parameters
///
/// ## Example
///
/// ```rust
/// # use bytemuck_derive::{Pod, Zeroable};
///
/// #[derive(Copy, Clone, Pod, Zeroable)]
/// #[repr(C)]
/// struct Test {
/// a: u16,
/// b: u16,
/// }
/// ```
#[proc_macro_derive(Pod)]
pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
proc_macro::TokenStream::from(expanded)
}
/// Derive the `Zeroable` trait for a struct
///
/// The macro ensures that the struct follows all the the safety requirements
/// for the `Zeroable` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
///
/// - All fields ind the struct must to implement `Zeroable`
///
/// ## Example
///
/// ```rust
/// # use bytemuck_derive::{Zeroable};
///
/// #[derive(Copy, Clone, Zeroable)]
/// #[repr(C)]
/// struct Test {
/// a: u16,
/// b: u16,
/// }
/// ```
#[proc_macro_derive(Zeroable)]
pub fn derive_zeroable(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
proc_macro::TokenStream::from(expanded)
}
/// Derive the `TransparentWrapper` trait for a struct
///
/// The macro ensures that the struct follows all the the safety requirements
/// for the `TransparentWrapper` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
///
/// - The struct must be `#[repr(transparent)]
/// - The struct must contain the `Wrapped` type
///
/// If the struct only contains a single field, the `Wrapped` type will
/// automatically be determined if there is more then one field in the struct,
/// you need to specify the `Wrapped` type using `#[transparent(T)]`
///
/// ## Example
///
/// ```rust
/// # use bytemuck_derive::TransparentWrapper;
/// # use std::marker::PhantomData;
///
/// #[derive(Copy, Clone, TransparentWrapper)]
/// #[repr(transparent)]
/// #[transparent(u16)]
/// struct Test<T> {
/// inner: u16,
/// extra: PhantomData<T>,
/// }
/// ```
#[proc_macro_derive(TransparentWrapper, attributes(transparent))]
pub fn derive_transparent(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
input as DeriveInput
));
proc_macro::TokenStream::from(expanded)
}
/// Basic wrapper for error handling
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
quote! {
compile_error!(#err);
}
})
}
fn derive_marker_trait_inner<Trait: Derivable>(
input: DeriveInput,
) -> Result<TokenStream, &'static str> {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();
let trait_ = Trait::ident();
Trait::check_attributes(&input.attrs)?;
let asserts = Trait::struct_asserts(&input)?;
let trait_params = Trait::generic_params(&input)?;
Ok(quote! {
#asserts
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {}
})
}

207
derive/src/traits.rs Normal file
View File

@ -0,0 +1,207 @@
use proc_macro2::{Ident, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput,
Fields, Type,
};
pub trait Derivable {
fn ident() -> TokenStream;
fn generic_params(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
Ok(quote!())
}
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str>;
fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> {
Ok(())
}
}
pub struct Pod;
impl Derivable for Pod {
fn ident() -> TokenStream {
quote!(::bytemuck::Pod)
}
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
if !input.generics.params.is_empty() {
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
}
let assert_no_padding = generate_assert_no_padding(input)?;
let assert_fields_are_pod =
generate_fields_are_trait(input, Self::ident())?;
Ok(quote!(
#assert_no_padding
#assert_fields_are_pod
))
}
fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> {
let repr = get_repr(attributes);
match repr.as_ref().map(|repr| repr.as_str()) {
Some("C") => Ok(()),
Some("transparent") => Ok(()),
_ => {
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]")
}
}
}
}
pub struct Zeroable;
impl Derivable for Zeroable {
fn ident() -> TokenStream {
quote!(::bytemuck::Zeroable)
}
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
generate_fields_are_trait(input, Self::ident())
}
}
pub struct TransparentWrapper;
impl TransparentWrapper {
fn get_wrapper_type(
attributes: &[Attribute], fields: &Fields,
) -> Option<TokenStream> {
let transparent_param = get_simple_attr(attributes, "transparent");
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
let mut types = get_field_types(fields);
let first_type = types.next();
if let Some(_) = types.next() {
// can't guess param type if there is more than one field
return None;
} else {
first_type.map(|ty| ty.to_token_stream())
}
})
}
}
impl Derivable for TransparentWrapper {
fn ident() -> TokenStream {
quote!(::bytemuck::TransparentWrapper)
}
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let fields = get_struct_fields(input)?;
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>))
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
}
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) {
Some(wrapped_type) => wrapped_type.to_string(),
None => unreachable!(), /* other code will already reject this derive */
};
let mut wrapped_fields = fields
.iter()
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
if let None = wrapped_fields.next() {
return Err("TransparentWrapper must have one field of the wrapped type");
};
if let Some(_) = wrapped_fields.next() {
Err("TransparentWrapper can only have one field of the wrapped type")
} else {
Ok(quote!())
}
}
fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> {
let repr = get_repr(attributes);
match repr.as_ref().map(|repr| repr.as_str()) {
Some("transparent") => Ok(()),
_ => {
Err("TransparentWrapper requires the struct to be #[repr(transparent)]")
}
}
}
}
fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
if let Data::Struct(DataStruct { fields, .. }) = &input.data {
Ok(fields)
} else {
Err("deriving this trait is only supported for structs")
}
}
fn get_field_types<'a>(
fields: &'a Fields,
) -> impl Iterator<Item = &'a Type> + 'a {
fields.iter().map(|field| &field.ty)
}
/// Check that a struct has no padding by asserting that the size of the struct
/// is equal to the sum of the size of it's fields
fn generate_assert_no_padding(
input: &DeriveInput,
) -> Result<TokenStream, &'static str> {
let struct_type = &input.ident;
let span = input.span();
let fields = get_struct_fields(input)?;
let field_types = get_field_types(&fields);
let struct_size =
quote_spanned!(span => core::mem::size_of::<#struct_type>());
let size_sum =
quote_spanned!(span => 0 #( + core::mem::size_of::<#field_types>() )*);
Ok(quote_spanned! {span => const _: fn() = || {
let _ = core::mem::transmute::<[u8; #struct_size], [u8; #size_sum]>;
};})
}
/// Check that all fields implement a given trait
fn generate_fields_are_trait(
input: &DeriveInput, trait_: TokenStream,
) -> Result<TokenStream, &'static str> {
let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl();
let fields = get_struct_fields(input)?;
let span = input.span();
let field_types = get_field_types(&fields);
Ok(quote_spanned! {span => #(const _: fn() = || {
fn check #impl_generics () #where_clause {
fn assert_impl<T: #trait_>() {}
assert_impl::<#field_types>();
}
};)*
})
}
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
match tokens.into_iter().next() {
Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
Some(TokenTree::Ident(ident)) => Some(ident),
_ => None,
}
}
/// get a simple #[foo(bar)] attribute, returning "bar"
fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
for attr in attributes {
if let (AttrStyle::Outer, Some(outer_ident), Some(inner_ident)) = (
&attr.style,
attr.path.get_ident(),
get_ident_from_stream(attr.tokens.clone()),
) {
if outer_ident.to_string() == attr_name {
return Some(inner_ident);
}
}
}
None
}
fn get_repr(attributes: &[Attribute]) -> Option<String> {
get_simple_attr(attributes, "repr").map(|ident| ident.to_string())
}

30
derive/tests/basic.rs Normal file
View File

@ -0,0 +1,30 @@
#![allow(dead_code)]
use bytemuck_derive::{Pod, TransparentWrapper, Zeroable};
use std::marker::PhantomData;
#[derive(Copy, Clone, Pod, Zeroable)]
#[repr(C)]
struct Test {
a: u16,
b: u16,
}
#[derive(Zeroable)]
struct ZeroGeneric<T: bytemuck::Zeroable> {
a: T,
}
#[derive(TransparentWrapper)]
#[repr(transparent)]
struct TransparentSingle {
a: u16,
}
#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(u16)]
struct TransparentWithZeroSized<T> {
a: u16,
b: PhantomData<T>,
}

View File

@ -81,6 +81,9 @@ pub use offset_of::*;
mod transparent;
pub use transparent::*;
#[cfg(feature = "derive")]
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper};
/*
Note(Lokathor): We've switched all of the `unwrap` to `match` because there is

25
tests/derive.rs Normal file
View File

@ -0,0 +1,25 @@
#![cfg(feature = "derive")]
#![allow(dead_code)]
use bytemuck::{Zeroable, Pod, TransparentWrapper};
#[derive(Copy, Clone, Pod, Zeroable)]
#[repr(C)]
struct Test {
a: u16,
b: u16,
}
#[derive(TransparentWrapper)]
#[repr(transparent)]
struct TransparentSingle {
a: u16,
}
#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(u16)]
struct TransparentWithZeroSized {
a: u16,
b: ()
}