mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-21 14:22:26 +00:00
add basic derive macro for Pod, Zeroable and TransparentWrapper for structs (#30)
* add basic derive macro for Pod and Zeroable for structs * add derive macro for TransparentWrapper * use core::mem::size_of instead of std::mem::size_of in generated code * cleanup error handling a bit * remove unneeded iter logic * remove unneeded clone and order impl * fix generics * fix doc typo Co-authored-by: Lucien Greathouse <me@lpghatguy.com> * remove unneeded lifetime anotation * use unreachable for already rejected patch Co-authored-by: Lucien Greathouse <me@lpghatguy.com>
This commit is contained in:
parent
94d71d9925
commit
cf944452b7
@ -16,6 +16,12 @@ exclude = ["/pedantic.bat"]
|
||||
extern_crate_alloc = []
|
||||
extern_crate_std = ["extern_crate_alloc"]
|
||||
zeroable_maybe_uninit = []
|
||||
derive = ["bytemuck_derive"]
|
||||
|
||||
[dependencies]
|
||||
bytemuck_derive = { version = "1.0", path = "derive", optional = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
||||
[workspace]
|
23
derive/Cargo.toml
Normal file
23
derive/Cargo.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "bytemuck_derive"
|
||||
description = "A crate for mucking around with piles of bytes."
|
||||
version = "1.0.0"
|
||||
authors = ["Lokathor <zefria@gmail.com>"]
|
||||
repository = "https://github.com/Lokathor/bytemuck"
|
||||
readme = "README.md"
|
||||
keywords = ["transmute", "bytes", "casting"]
|
||||
categories = ["encoding", "no-std"]
|
||||
edition = "2018"
|
||||
license = "Zlib OR Apache-2.0 OR MIT"
|
||||
|
||||
[lib]
|
||||
name = "bytemuck_derive"
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
syn = "1.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
bytemuck = { version = "1.3.2-alpha.0", path = ".." }
|
140
derive/src/lib.rs
Normal file
140
derive/src/lib.rs
Normal file
@ -0,0 +1,140 @@
|
||||
//! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits
|
||||
|
||||
mod traits;
|
||||
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, DeriveInput};
|
||||
|
||||
use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable};
|
||||
|
||||
/// Derive the `Pod` trait for a struct
|
||||
///
|
||||
/// The macro ensures that the struct follows all the the safety requirements
|
||||
/// for the `Pod` trait.
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed
|
||||
///
|
||||
/// - All fields in the struct must implement `Pod`
|
||||
/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
|
||||
/// - The struct must not contain any padding bytes
|
||||
/// - The struct contains no generic parameters
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck_derive::{Pod, Zeroable};
|
||||
///
|
||||
/// #[derive(Copy, Clone, Pod, Zeroable)]
|
||||
/// #[repr(C)]
|
||||
/// struct Test {
|
||||
/// a: u16,
|
||||
/// b: u16,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(Pod)]
|
||||
pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
||||
let expanded =
|
||||
derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
|
||||
|
||||
proc_macro::TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// Derive the `Zeroable` trait for a struct
|
||||
///
|
||||
/// The macro ensures that the struct follows all the the safety requirements
|
||||
/// for the `Zeroable` trait.
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed
|
||||
///
|
||||
/// - All fields ind the struct must to implement `Zeroable`
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck_derive::{Zeroable};
|
||||
///
|
||||
/// #[derive(Copy, Clone, Zeroable)]
|
||||
/// #[repr(C)]
|
||||
/// struct Test {
|
||||
/// a: u16,
|
||||
/// b: u16,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(Zeroable)]
|
||||
pub fn derive_zeroable(
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
let expanded =
|
||||
derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
|
||||
|
||||
proc_macro::TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// Derive the `TransparentWrapper` trait for a struct
|
||||
///
|
||||
/// The macro ensures that the struct follows all the the safety requirements
|
||||
/// for the `TransparentWrapper` trait.
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed
|
||||
///
|
||||
/// - The struct must be `#[repr(transparent)]
|
||||
/// - The struct must contain the `Wrapped` type
|
||||
///
|
||||
/// If the struct only contains a single field, the `Wrapped` type will
|
||||
/// automatically be determined if there is more then one field in the struct,
|
||||
/// you need to specify the `Wrapped` type using `#[transparent(T)]`
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck_derive::TransparentWrapper;
|
||||
/// # use std::marker::PhantomData;
|
||||
///
|
||||
/// #[derive(Copy, Clone, TransparentWrapper)]
|
||||
/// #[repr(transparent)]
|
||||
/// #[transparent(u16)]
|
||||
/// struct Test<T> {
|
||||
/// inner: u16,
|
||||
/// extra: PhantomData<T>,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(TransparentWrapper, attributes(transparent))]
|
||||
pub fn derive_transparent(
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
|
||||
input as DeriveInput
|
||||
));
|
||||
|
||||
proc_macro::TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// Basic wrapper for error handling
|
||||
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
||||
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
|
||||
quote! {
|
||||
compile_error!(#err);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn derive_marker_trait_inner<Trait: Derivable>(
|
||||
input: DeriveInput,
|
||||
) -> Result<TokenStream, &'static str> {
|
||||
let name = &input.ident;
|
||||
|
||||
let (impl_generics, ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
|
||||
let trait_ = Trait::ident();
|
||||
Trait::check_attributes(&input.attrs)?;
|
||||
let asserts = Trait::struct_asserts(&input)?;
|
||||
let trait_params = Trait::generic_params(&input)?;
|
||||
|
||||
Ok(quote! {
|
||||
#asserts
|
||||
|
||||
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {}
|
||||
})
|
||||
}
|
207
derive/src/traits.rs
Normal file
207
derive/src/traits.rs
Normal file
@ -0,0 +1,207 @@
|
||||
use proc_macro2::{Ident, TokenStream, TokenTree};
|
||||
use quote::{quote, quote_spanned, ToTokens};
|
||||
use syn::{
|
||||
spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput,
|
||||
Fields, Type,
|
||||
};
|
||||
|
||||
pub trait Derivable {
|
||||
fn ident() -> TokenStream;
|
||||
fn generic_params(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
Ok(quote!())
|
||||
}
|
||||
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str>;
|
||||
fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Pod;
|
||||
|
||||
impl Derivable for Pod {
|
||||
fn ident() -> TokenStream {
|
||||
quote!(::bytemuck::Pod)
|
||||
}
|
||||
|
||||
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
if !input.generics.params.is_empty() {
|
||||
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
|
||||
}
|
||||
|
||||
let assert_no_padding = generate_assert_no_padding(input)?;
|
||||
let assert_fields_are_pod =
|
||||
generate_fields_are_trait(input, Self::ident())?;
|
||||
|
||||
Ok(quote!(
|
||||
#assert_no_padding
|
||||
#assert_fields_are_pod
|
||||
))
|
||||
}
|
||||
|
||||
fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||
let repr = get_repr(attributes);
|
||||
match repr.as_ref().map(|repr| repr.as_str()) {
|
||||
Some("C") => Ok(()),
|
||||
Some("transparent") => Ok(()),
|
||||
_ => {
|
||||
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Zeroable;
|
||||
|
||||
impl Derivable for Zeroable {
|
||||
fn ident() -> TokenStream {
|
||||
quote!(::bytemuck::Zeroable)
|
||||
}
|
||||
|
||||
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
generate_fields_are_trait(input, Self::ident())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TransparentWrapper;
|
||||
|
||||
impl TransparentWrapper {
|
||||
fn get_wrapper_type(
|
||||
attributes: &[Attribute], fields: &Fields,
|
||||
) -> Option<TokenStream> {
|
||||
let transparent_param = get_simple_attr(attributes, "transparent");
|
||||
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
|
||||
let mut types = get_field_types(fields);
|
||||
let first_type = types.next();
|
||||
if let Some(_) = types.next() {
|
||||
// can't guess param type if there is more than one field
|
||||
return None;
|
||||
} else {
|
||||
first_type.map(|ty| ty.to_token_stream())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Derivable for TransparentWrapper {
|
||||
fn ident() -> TokenStream {
|
||||
quote!(::bytemuck::TransparentWrapper)
|
||||
}
|
||||
|
||||
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
let fields = get_struct_fields(input)?;
|
||||
|
||||
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>))
|
||||
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
|
||||
}
|
||||
|
||||
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
let fields = get_struct_fields(input)?;
|
||||
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) {
|
||||
Some(wrapped_type) => wrapped_type.to_string(),
|
||||
None => unreachable!(), /* other code will already reject this derive */
|
||||
};
|
||||
let mut wrapped_fields = fields
|
||||
.iter()
|
||||
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
|
||||
if let None = wrapped_fields.next() {
|
||||
return Err("TransparentWrapper must have one field of the wrapped type");
|
||||
};
|
||||
if let Some(_) = wrapped_fields.next() {
|
||||
Err("TransparentWrapper can only have one field of the wrapped type")
|
||||
} else {
|
||||
Ok(quote!())
|
||||
}
|
||||
}
|
||||
|
||||
fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||
let repr = get_repr(attributes);
|
||||
|
||||
match repr.as_ref().map(|repr| repr.as_str()) {
|
||||
Some("transparent") => Ok(()),
|
||||
_ => {
|
||||
Err("TransparentWrapper requires the struct to be #[repr(transparent)]")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
|
||||
if let Data::Struct(DataStruct { fields, .. }) = &input.data {
|
||||
Ok(fields)
|
||||
} else {
|
||||
Err("deriving this trait is only supported for structs")
|
||||
}
|
||||
}
|
||||
|
||||
fn get_field_types<'a>(
|
||||
fields: &'a Fields,
|
||||
) -> impl Iterator<Item = &'a Type> + 'a {
|
||||
fields.iter().map(|field| &field.ty)
|
||||
}
|
||||
|
||||
/// Check that a struct has no padding by asserting that the size of the struct
|
||||
/// is equal to the sum of the size of it's fields
|
||||
fn generate_assert_no_padding(
|
||||
input: &DeriveInput,
|
||||
) -> Result<TokenStream, &'static str> {
|
||||
let struct_type = &input.ident;
|
||||
let span = input.span();
|
||||
let fields = get_struct_fields(input)?;
|
||||
|
||||
let field_types = get_field_types(&fields);
|
||||
let struct_size =
|
||||
quote_spanned!(span => core::mem::size_of::<#struct_type>());
|
||||
let size_sum =
|
||||
quote_spanned!(span => 0 #( + core::mem::size_of::<#field_types>() )*);
|
||||
|
||||
Ok(quote_spanned! {span => const _: fn() = || {
|
||||
let _ = core::mem::transmute::<[u8; #struct_size], [u8; #size_sum]>;
|
||||
};})
|
||||
}
|
||||
|
||||
/// Check that all fields implement a given trait
|
||||
fn generate_fields_are_trait(
|
||||
input: &DeriveInput, trait_: TokenStream,
|
||||
) -> Result<TokenStream, &'static str> {
|
||||
let (impl_generics, _ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
let fields = get_struct_fields(input)?;
|
||||
let span = input.span();
|
||||
let field_types = get_field_types(&fields);
|
||||
Ok(quote_spanned! {span => #(const _: fn() = || {
|
||||
fn check #impl_generics () #where_clause {
|
||||
fn assert_impl<T: #trait_>() {}
|
||||
assert_impl::<#field_types>();
|
||||
}
|
||||
};)*
|
||||
})
|
||||
}
|
||||
|
||||
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
|
||||
match tokens.into_iter().next() {
|
||||
Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
|
||||
Some(TokenTree::Ident(ident)) => Some(ident),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// get a simple #[foo(bar)] attribute, returning "bar"
|
||||
fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
|
||||
for attr in attributes {
|
||||
if let (AttrStyle::Outer, Some(outer_ident), Some(inner_ident)) = (
|
||||
&attr.style,
|
||||
attr.path.get_ident(),
|
||||
get_ident_from_stream(attr.tokens.clone()),
|
||||
) {
|
||||
if outer_ident.to_string() == attr_name {
|
||||
return Some(inner_ident);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn get_repr(attributes: &[Attribute]) -> Option<String> {
|
||||
get_simple_attr(attributes, "repr").map(|ident| ident.to_string())
|
||||
}
|
30
derive/tests/basic.rs
Normal file
30
derive/tests/basic.rs
Normal file
@ -0,0 +1,30 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytemuck_derive::{Pod, TransparentWrapper, Zeroable};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||
#[repr(C)]
|
||||
struct Test {
|
||||
a: u16,
|
||||
b: u16,
|
||||
}
|
||||
|
||||
#[derive(Zeroable)]
|
||||
struct ZeroGeneric<T: bytemuck::Zeroable> {
|
||||
a: T,
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
struct TransparentSingle {
|
||||
a: u16,
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
#[transparent(u16)]
|
||||
struct TransparentWithZeroSized<T> {
|
||||
a: u16,
|
||||
b: PhantomData<T>,
|
||||
}
|
@ -81,6 +81,9 @@ pub use offset_of::*;
|
||||
mod transparent;
|
||||
pub use transparent::*;
|
||||
|
||||
#[cfg(feature = "derive")]
|
||||
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper};
|
||||
|
||||
/*
|
||||
|
||||
Note(Lokathor): We've switched all of the `unwrap` to `match` because there is
|
||||
|
25
tests/derive.rs
Normal file
25
tests/derive.rs
Normal file
@ -0,0 +1,25 @@
|
||||
#![cfg(feature = "derive")]
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytemuck::{Zeroable, Pod, TransparentWrapper};
|
||||
|
||||
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||
#[repr(C)]
|
||||
struct Test {
|
||||
a: u16,
|
||||
b: u16,
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
struct TransparentSingle {
|
||||
a: u16,
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
#[transparent(u16)]
|
||||
struct TransparentWithZeroSized {
|
||||
a: u16,
|
||||
b: ()
|
||||
}
|
Loading…
Reference in New Issue
Block a user