mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-11-21 22:32:23 +00:00
add derive macro for Contiguous (#31)
This commit is contained in:
parent
e202fa2756
commit
24b71e078f
@ -6,7 +6,7 @@ use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, DeriveInput};
|
||||
|
||||
use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable};
|
||||
use crate::traits::{Contiguous, Derivable, Pod, TransparentWrapper, Zeroable};
|
||||
|
||||
/// Derive the `Pod` trait for a struct
|
||||
///
|
||||
@ -110,6 +110,40 @@ pub fn derive_transparent(
|
||||
proc_macro::TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// Derive the `Contiguous` trait for an enum
|
||||
///
|
||||
/// The macro ensures that the enum follows all the the safety requirements
|
||||
/// for the `Contiguous` trait.
|
||||
///
|
||||
/// The following constraints need to be satisfied for the macro to succeed
|
||||
///
|
||||
/// - The enum must be `#[repr(Int)]`
|
||||
/// - The enum must be fieldless
|
||||
/// - The enum discriminants must form a contiguous range
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use bytemuck_derive::{Contiguous};
|
||||
///
|
||||
/// #[derive(Copy, Clone, Contiguous)]
|
||||
/// #[repr(u8)]
|
||||
/// enum Test {
|
||||
/// A = 0,
|
||||
/// B = 1,
|
||||
/// C = 2,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(Contiguous)]
|
||||
pub fn derive_contiguous(
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
let expanded =
|
||||
derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
|
||||
|
||||
proc_macro::TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// Basic wrapper for error handling
|
||||
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
|
||||
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
|
||||
@ -131,10 +165,13 @@ fn derive_marker_trait_inner<Trait: Derivable>(
|
||||
Trait::check_attributes(&input.attrs)?;
|
||||
let asserts = Trait::struct_asserts(&input)?;
|
||||
let trait_params = Trait::generic_params(&input)?;
|
||||
let trait_impl = Trait::trait_impl(&input)?;
|
||||
|
||||
Ok(quote! {
|
||||
#asserts
|
||||
|
||||
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {}
|
||||
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {
|
||||
#trait_impl
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -1,8 +1,9 @@
|
||||
use proc_macro2::{Ident, TokenStream, TokenTree};
|
||||
use quote::{quote, quote_spanned, ToTokens};
|
||||
use syn::{
|
||||
spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput,
|
||||
Fields, Type,
|
||||
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
|
||||
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
|
||||
Variant,
|
||||
};
|
||||
|
||||
pub trait Derivable {
|
||||
@ -10,10 +11,15 @@ pub trait Derivable {
|
||||
fn generic_params(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
Ok(quote!())
|
||||
}
|
||||
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str>;
|
||||
fn struct_asserts(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
Ok(quote!())
|
||||
}
|
||||
fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> {
|
||||
Ok(())
|
||||
}
|
||||
fn trait_impl(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
Ok(quote!())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Pod;
|
||||
@ -125,6 +131,55 @@ impl Derivable for TransparentWrapper {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Contiguous;
|
||||
|
||||
impl Derivable for Contiguous {
|
||||
fn ident() -> TokenStream {
|
||||
quote!(::bytemuck::Contiguous)
|
||||
}
|
||||
|
||||
fn trait_impl(input: &DeriveInput) -> Result<TokenStream, &'static str> {
|
||||
let repr = get_repr(&input.attrs)
|
||||
.ok_or("Contiguous requires the enum to be #[repr(Int)]")?;
|
||||
|
||||
if !repr.starts_with('u') && !repr.starts_with('i') {
|
||||
return Err("Contiguous requires the enum to be #[repr(Int)]");
|
||||
}
|
||||
|
||||
let variants = get_enum_variants(input)?;
|
||||
let mut variants_with_discriminator =
|
||||
VariantDiscriminantIterator::new(variants);
|
||||
|
||||
let (min, max, count) = variants_with_discriminator.try_fold(
|
||||
(i64::max_value(), i64::min_value(), 0),
|
||||
|(min, max, count), res| {
|
||||
let discriminator = res?;
|
||||
Ok((
|
||||
i64::min(min, discriminator),
|
||||
i64::max(max, discriminator),
|
||||
count + 1,
|
||||
))
|
||||
},
|
||||
)?;
|
||||
|
||||
if max - min != count - 1 {
|
||||
return Err(
|
||||
"Contiguous requires the enum discriminants to be contiguous",
|
||||
);
|
||||
}
|
||||
|
||||
let repr_ident = Ident::new(&repr, input.span());
|
||||
let min_lit = LitInt::new(&format!("{}", min), input.span());
|
||||
let max_lit = LitInt::new(&format!("{}", max), input.span());
|
||||
|
||||
Ok(quote! {
|
||||
type Int = #repr_ident;
|
||||
const MIN_VALUE: #repr_ident = #min_lit;
|
||||
const MAX_VALUE: #repr_ident = #max_lit;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
|
||||
if let Data::Struct(DataStruct { fields, .. }) = &input.data {
|
||||
Ok(fields)
|
||||
@ -133,6 +188,16 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_enum_variants<'a>(
|
||||
input: &'a DeriveInput,
|
||||
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
|
||||
if let Data::Enum(DataEnum { variants, .. }) = &input.data {
|
||||
Ok(variants.iter())
|
||||
} else {
|
||||
Err("deriving this trait is only supported for enums")
|
||||
}
|
||||
}
|
||||
|
||||
fn get_field_types<'a>(
|
||||
fields: &'a Fields,
|
||||
) -> impl Iterator<Item = &'a Type> + 'a {
|
||||
@ -205,3 +270,53 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
|
||||
fn get_repr(attributes: &[Attribute]) -> Option<String> {
|
||||
get_simple_attr(attributes, "repr").map(|ident| ident.to_string())
|
||||
}
|
||||
|
||||
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
|
||||
inner: I,
|
||||
last_value: i64,
|
||||
}
|
||||
|
||||
impl<'a, I: Iterator<Item = &'a Variant> + 'a>
|
||||
VariantDiscriminantIterator<'a, I>
|
||||
{
|
||||
fn new(inner: I) -> Self {
|
||||
VariantDiscriminantIterator { inner, last_value: -1 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
|
||||
for VariantDiscriminantIterator<'a, I>
|
||||
{
|
||||
type Item = Result<i64, &'static str>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let variant = self.inner.next()?;
|
||||
if !variant.fields.is_empty() {
|
||||
return Some(Err("Only fieldless enums are supported"));
|
||||
}
|
||||
|
||||
if let Some((_, discriminant)) = &variant.discriminant {
|
||||
let discriminant_value = match parse_int_expr(discriminant) {
|
||||
Ok(value) => value,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
self.last_value = discriminant_value;
|
||||
} else {
|
||||
self.last_value += 1;
|
||||
}
|
||||
|
||||
Some(Ok(self.last_value))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_int_expr(expr: &Expr) -> Result<i64, &'static str> {
|
||||
match expr {
|
||||
Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
|
||||
parse_int_expr(expr).map(|int| -int)
|
||||
}
|
||||
Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
|
||||
int.base10_parse().map_err(|_| "Invalid integer expression")
|
||||
}
|
||||
_ => Err("Not an integer expression"),
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytemuck_derive::{Pod, TransparentWrapper, Zeroable};
|
||||
use bytemuck_derive::{Contiguous, Pod, TransparentWrapper, Zeroable};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||
@ -28,3 +28,23 @@ struct TransparentWithZeroSized<T> {
|
||||
a: u16,
|
||||
b: PhantomData<T>,
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, Contiguous)]
|
||||
enum ContiguousWithValues {
|
||||
A = 0,
|
||||
B = 1,
|
||||
C = 2,
|
||||
D = 3,
|
||||
E = 4,
|
||||
}
|
||||
|
||||
#[repr(i8)]
|
||||
#[derive(Clone, Copy, Contiguous)]
|
||||
enum ContiguousWithImplicitValues {
|
||||
A = -10,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
E,
|
||||
}
|
||||
|
@ -82,7 +82,7 @@ mod transparent;
|
||||
pub use transparent::*;
|
||||
|
||||
#[cfg(feature = "derive")]
|
||||
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper};
|
||||
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper, Contiguous};
|
||||
|
||||
/*
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user