mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-25 08:12:26 +00:00
add basic derive macro for Pod, Zeroable and TransparentWrapper for structs (#30)
* add basic derive macro for Pod and Zeroable for structs * add derive macro for TransparentWrapper * use core::mem::size_of instead of std::mem::size_of in generated code * cleanup error handling a bit * remove unneeded iter logic * remove unneeded clone and order impl * fix generics * fix doc typo Co-authored-by: Lucien Greathouse <me@lpghatguy.com> * remove unneeded lifetime anotation * use unreachable for already rejected patch Co-authored-by: Lucien Greathouse <me@lpghatguy.com>
This commit is contained in:
parent
94d71d9925
commit
cf944452b7
@ -16,6 +16,12 @@ exclude = ["/pedantic.bat"]
|
|||||||
extern_crate_alloc = []
|
extern_crate_alloc = []
|
||||||
extern_crate_std = ["extern_crate_alloc"]
|
extern_crate_std = ["extern_crate_alloc"]
|
||||||
zeroable_maybe_uninit = []
|
zeroable_maybe_uninit = []
|
||||||
|
derive = ["bytemuck_derive"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
bytemuck_derive = { version = "1.0", path = "derive", optional = true }
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
all-features = true
|
all-features = true
|
||||||
|
|
||||||
|
[workspace]
|
23
derive/Cargo.toml
Normal file
23
derive/Cargo.toml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
[package]
|
||||||
|
name = "bytemuck_derive"
|
||||||
|
description = "A crate for mucking around with piles of bytes."
|
||||||
|
version = "1.0.0"
|
||||||
|
authors = ["Lokathor <zefria@gmail.com>"]
|
||||||
|
repository = "https://github.com/Lokathor/bytemuck"
|
||||||
|
readme = "README.md"
|
||||||
|
keywords = ["transmute", "bytes", "casting"]
|
||||||
|
categories = ["encoding", "no-std"]
|
||||||
|
edition = "2018"
|
||||||
|
license = "Zlib OR Apache-2.0 OR MIT"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "bytemuck_derive"
|
||||||
|
proc-macro = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
syn = "1.0"
|
||||||
|
quote = "1.0"
|
||||||
|
proc-macro2 = "1.0"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
bytemuck = { version = "1.3.2-alpha.0", path = ".." }
|
140
derive/src/lib.rs
Normal file
140
derive/src/lib.rs
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
//! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits
|
||||||
|
|
||||||
|
mod traits;
|
||||||
|
|
||||||
|
use proc_macro2::TokenStream;
|
||||||
|
use quote::quote;
|
||||||
|
use syn::{parse_macro_input, DeriveInput};
|
||||||
|
|
||||||
|
use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable};
|
||||||
|
|
||||||
|
/// Derive the `Pod` trait for a struct
|
||||||
|
///
|
||||||
|
/// The macro ensures that the struct follows all the the safety requirements
|
||||||
|
/// for the `Pod` trait.
|
||||||
|
///
|
||||||
|
/// The following constraints need to be satisfied for the macro to succeed
|
||||||
|
///
|
||||||
|
/// - All fields in the struct must implement `Pod`
|
||||||
|
/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
|
||||||
|
/// - The struct must not contain any padding bytes
|
||||||
|
/// - The struct contains no generic parameters
|
||||||
|
///
|
||||||
|
/// ## Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use bytemuck_derive::{Pod, Zeroable};
|
||||||
|
///
|
||||||
|
/// #[derive(Copy, Clone, Pod, Zeroable)]
|
||||||
|
/// #[repr(C)]
|
||||||
|
/// struct Test {
|
||||||
|
/// a: u16,
|
||||||
|
/// b: u16,
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[proc_macro_derive(Pod)]
|
||||||
|
pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
||||||
|
let expanded =
|
||||||
|
derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
|
||||||
|
|
||||||
|
proc_macro::TokenStream::from(expanded)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Derive the `Zeroable` trait for a struct
|
||||||
|
///
|
||||||
|
/// The macro ensures that the struct follows all the the safety requirements
|
||||||
|
/// for the `Zeroable` trait.
|
||||||
|
///
|
||||||
|
/// The following constraints need to be satisfied for the macro to succeed
|
||||||
|
///
|
||||||
|
/// - All fields ind the struct must to implement `Zeroable`
|
||||||
|
///
|
||||||
|
/// ## Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use bytemuck_derive::{Zeroable};
|
||||||
|
///
|
||||||
|
/// #[derive(Copy, Clone, Zeroable)]
|
||||||
|
/// #[repr(C)]
|
||||||
|
/// struct Test {
|
||||||
|
/// a: u16,
|
||||||
|
/// b: u16,
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[proc_macro_derive(Zeroable)]
|
||||||
|
pub fn derive_zeroable(
|
||||||
|
input: proc_macro::TokenStream,
|
||||||
|
) -> proc_macro::TokenStream {
|
||||||
|
let expanded =
|
||||||
|
derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
|
||||||
|
|
||||||
|
proc_macro::TokenStream::from(expanded)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Derive the `TransparentWrapper` trait for a struct
|
||||||
|
///
|
||||||
|
/// The macro ensures that the struct follows all the the safety requirements
|
||||||
|
/// for the `TransparentWrapper` trait.
|
||||||
|
///
|
||||||
|
/// The following constraints need to be satisfied for the macro to succeed
|
||||||
|
///
|
||||||
|
/// - The struct must be `#[repr(transparent)]
|
||||||
|
/// - The struct must contain the `Wrapped` type
|
||||||
|
///
|
||||||
|
/// If the struct only contains a single field, the `Wrapped` type will
|
||||||
|
/// automatically be determined if there is more then one field in the struct,
|
||||||
|
/// you need to specify the `Wrapped` type using `#[transparent(T)]`
|
||||||
|
///
|
||||||
|
/// ## Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use bytemuck_derive::TransparentWrapper;
|
||||||
|
/// # use std::marker::PhantomData;
|
||||||
|
///
|
||||||
|
/// #[derive(Copy, Clone, TransparentWrapper)]
|
||||||
|
/// #[repr(transparent)]
|
||||||
|
/// #[transparent(u16)]
|
||||||
|
/// struct Test<T> {
|
||||||
|
/// inner: u16,
|
||||||
|
/// extra: PhantomData<T>,
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[proc_macro_derive(TransparentWrapper, attributes(transparent))]
|
||||||
|
pub fn derive_transparent(
|
||||||
|
input: proc_macro::TokenStream,
|
||||||
|
) -> proc_macro::TokenStream {
|
||||||
|
let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
|
||||||
|
input as DeriveInput
|
||||||
|
));
|
||||||
|
|
||||||
|
proc_macro::TokenStream::from(expanded)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Basic wrapper for error handling
|
||||||
|
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
||||||
|
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
|
||||||
|
quote! {
|
||||||
|
compile_error!(#err);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn derive_marker_trait_inner<Trait: Derivable>(
|
||||||
|
input: DeriveInput,
|
||||||
|
) -> Result<TokenStream, &'static str> {
|
||||||
|
let name = &input.ident;
|
||||||
|
|
||||||
|
let (impl_generics, ty_generics, where_clause) =
|
||||||
|
input.generics.split_for_impl();
|
||||||
|
|
||||||
|
let trait_ = Trait::ident();
|
||||||
|
Trait::check_attributes(&input.attrs)?;
|
||||||
|
let asserts = Trait::struct_asserts(&input)?;
|
||||||
|
let trait_params = Trait::generic_params(&input)?;
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#asserts
|
||||||
|
|
||||||
|
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {}
|
||||||
|
})
|
||||||
|
}
|
207
derive/src/traits.rs
Normal file
207
derive/src/traits.rs
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
use proc_macro2::{Ident, TokenStream, TokenTree};
|
||||||
|
use quote::{quote, quote_spanned, ToTokens};
|
||||||
|
use syn::{
|
||||||
|
spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput,
|
||||||
|
Fields, Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub trait Derivable {
|
||||||
|
fn ident() -> TokenStream;
|
||||||
|
fn generic_params(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
|
Ok(quote!())
|
||||||
|
}
|
||||||
|
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str>;
|
||||||
|
fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Pod;
|
||||||
|
|
||||||
|
impl Derivable for Pod {
|
||||||
|
fn ident() -> TokenStream {
|
||||||
|
quote!(::bytemuck::Pod)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
|
if !input.generics.params.is_empty() {
|
||||||
|
return Err("Pod requires cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
|
||||||
|
}
|
||||||
|
|
||||||
|
let assert_no_padding = generate_assert_no_padding(input)?;
|
||||||
|
let assert_fields_are_pod =
|
||||||
|
generate_fields_are_trait(input, Self::ident())?;
|
||||||
|
|
||||||
|
Ok(quote!(
|
||||||
|
#assert_no_padding
|
||||||
|
#assert_fields_are_pod
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||||
|
let repr = get_repr(attributes);
|
||||||
|
match repr.as_ref().map(|repr| repr.as_str()) {
|
||||||
|
Some("C") => Ok(()),
|
||||||
|
Some("transparent") => Ok(()),
|
||||||
|
_ => {
|
||||||
|
Err("Pod requires the struct to be #[repr(C)] or #[repr(transparent)]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Zeroable;
|
||||||
|
|
||||||
|
impl Derivable for Zeroable {
|
||||||
|
fn ident() -> TokenStream {
|
||||||
|
quote!(::bytemuck::Zeroable)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
|
generate_fields_are_trait(input, Self::ident())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TransparentWrapper;
|
||||||
|
|
||||||
|
impl TransparentWrapper {
|
||||||
|
fn get_wrapper_type(
|
||||||
|
attributes: &[Attribute], fields: &Fields,
|
||||||
|
) -> Option<TokenStream> {
|
||||||
|
let transparent_param = get_simple_attr(attributes, "transparent");
|
||||||
|
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
|
||||||
|
let mut types = get_field_types(fields);
|
||||||
|
let first_type = types.next();
|
||||||
|
if let Some(_) = types.next() {
|
||||||
|
// can't guess param type if there is more than one field
|
||||||
|
return None;
|
||||||
|
} else {
|
||||||
|
first_type.map(|ty| ty.to_token_stream())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Derivable for TransparentWrapper {
|
||||||
|
fn ident() -> TokenStream {
|
||||||
|
quote!(::bytemuck::TransparentWrapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generic_params(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
|
let fields = get_struct_fields(input)?;
|
||||||
|
|
||||||
|
Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>))
|
||||||
|
.ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||||
|
let fields = get_struct_fields(input)?;
|
||||||
|
let wrapped_type = match Self::get_wrapper_type(&input.attrs, fields) {
|
||||||
|
Some(wrapped_type) => wrapped_type.to_string(),
|
||||||
|
None => unreachable!(), /* other code will already reject this derive */
|
||||||
|
};
|
||||||
|
let mut wrapped_fields = fields
|
||||||
|
.iter()
|
||||||
|
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
|
||||||
|
if let None = wrapped_fields.next() {
|
||||||
|
return Err("TransparentWrapper must have one field of the wrapped type");
|
||||||
|
};
|
||||||
|
if let Some(_) = wrapped_fields.next() {
|
||||||
|
Err("TransparentWrapper can only have one field of the wrapped type")
|
||||||
|
} else {
|
||||||
|
Ok(quote!())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_attributes(attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||||
|
let repr = get_repr(attributes);
|
||||||
|
|
||||||
|
match repr.as_ref().map(|repr| repr.as_str()) {
|
||||||
|
Some("transparent") => Ok(()),
|
||||||
|
_ => {
|
||||||
|
Err("TransparentWrapper requires the struct to be #[repr(transparent)]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
|
||||||
|
if let Data::Struct(DataStruct { fields, .. }) = &input.data {
|
||||||
|
Ok(fields)
|
||||||
|
} else {
|
||||||
|
Err("deriving this trait is only supported for structs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_field_types<'a>(
|
||||||
|
fields: &'a Fields,
|
||||||
|
) -> impl Iterator<Item = &'a Type> + 'a {
|
||||||
|
fields.iter().map(|field| &field.ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check that a struct has no padding by asserting that the size of the struct
|
||||||
|
/// is equal to the sum of the size of it's fields
|
||||||
|
fn generate_assert_no_padding(
|
||||||
|
input: &DeriveInput,
|
||||||
|
) -> Result<TokenStream, &'static str> {
|
||||||
|
let struct_type = &input.ident;
|
||||||
|
let span = input.span();
|
||||||
|
let fields = get_struct_fields(input)?;
|
||||||
|
|
||||||
|
let field_types = get_field_types(&fields);
|
||||||
|
let struct_size =
|
||||||
|
quote_spanned!(span => core::mem::size_of::<#struct_type>());
|
||||||
|
let size_sum =
|
||||||
|
quote_spanned!(span => 0 #( + core::mem::size_of::<#field_types>() )*);
|
||||||
|
|
||||||
|
Ok(quote_spanned! {span => const _: fn() = || {
|
||||||
|
let _ = core::mem::transmute::<[u8; #struct_size], [u8; #size_sum]>;
|
||||||
|
};})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check that all fields implement a given trait
|
||||||
|
fn generate_fields_are_trait(
|
||||||
|
input: &DeriveInput, trait_: TokenStream,
|
||||||
|
) -> Result<TokenStream, &'static str> {
|
||||||
|
let (impl_generics, _ty_generics, where_clause) =
|
||||||
|
input.generics.split_for_impl();
|
||||||
|
let fields = get_struct_fields(input)?;
|
||||||
|
let span = input.span();
|
||||||
|
let field_types = get_field_types(&fields);
|
||||||
|
Ok(quote_spanned! {span => #(const _: fn() = || {
|
||||||
|
fn check #impl_generics () #where_clause {
|
||||||
|
fn assert_impl<T: #trait_>() {}
|
||||||
|
assert_impl::<#field_types>();
|
||||||
|
}
|
||||||
|
};)*
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
|
||||||
|
match tokens.into_iter().next() {
|
||||||
|
Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
|
||||||
|
Some(TokenTree::Ident(ident)) => Some(ident),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// get a simple #[foo(bar)] attribute, returning "bar"
|
||||||
|
fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
|
||||||
|
for attr in attributes {
|
||||||
|
if let (AttrStyle::Outer, Some(outer_ident), Some(inner_ident)) = (
|
||||||
|
&attr.style,
|
||||||
|
attr.path.get_ident(),
|
||||||
|
get_ident_from_stream(attr.tokens.clone()),
|
||||||
|
) {
|
||||||
|
if outer_ident.to_string() == attr_name {
|
||||||
|
return Some(inner_ident);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_repr(attributes: &[Attribute]) -> Option<String> {
|
||||||
|
get_simple_attr(attributes, "repr").map(|ident| ident.to_string())
|
||||||
|
}
|
30
derive/tests/basic.rs
Normal file
30
derive/tests/basic.rs
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
|
use bytemuck_derive::{Pod, TransparentWrapper, Zeroable};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||||
|
#[repr(C)]
|
||||||
|
struct Test {
|
||||||
|
a: u16,
|
||||||
|
b: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Zeroable)]
|
||||||
|
struct ZeroGeneric<T: bytemuck::Zeroable> {
|
||||||
|
a: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct TransparentSingle {
|
||||||
|
a: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[transparent(u16)]
|
||||||
|
struct TransparentWithZeroSized<T> {
|
||||||
|
a: u16,
|
||||||
|
b: PhantomData<T>,
|
||||||
|
}
|
@ -81,6 +81,9 @@ pub use offset_of::*;
|
|||||||
mod transparent;
|
mod transparent;
|
||||||
pub use transparent::*;
|
pub use transparent::*;
|
||||||
|
|
||||||
|
#[cfg(feature = "derive")]
|
||||||
|
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
||||||
Note(Lokathor): We've switched all of the `unwrap` to `match` because there is
|
Note(Lokathor): We've switched all of the `unwrap` to `match` because there is
|
||||||
|
25
tests/derive.rs
Normal file
25
tests/derive.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#![cfg(feature = "derive")]
|
||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
|
use bytemuck::{Zeroable, Pod, TransparentWrapper};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||||
|
#[repr(C)]
|
||||||
|
struct Test {
|
||||||
|
a: u16,
|
||||||
|
b: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct TransparentSingle {
|
||||||
|
a: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[transparent(u16)]
|
||||||
|
struct TransparentWithZeroSized {
|
||||||
|
a: u16,
|
||||||
|
b: ()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user