diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 041f68b..7208ac3 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput}; -use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable}; +use crate::traits::{Contiguous, Derivable, Pod, TransparentWrapper, Zeroable}; /// Derive the `Pod` trait for a struct /// @@ -110,6 +110,40 @@ pub fn derive_transparent( proc_macro::TokenStream::from(expanded) } +/// Derive the `Contiguous` trait for an enum +/// +/// The macro ensures that the enum follows all the the safety requirements +/// for the `Contiguous` trait. +/// +/// The following constraints need to be satisfied for the macro to succeed +/// +/// - The enum must be `#[repr(Int)]` +/// - The enum must be fieldless +/// - The enum discriminants must form a contiguous range +/// +/// ## Example +/// +/// ```rust +/// # use bytemuck_derive::{Contiguous}; +/// +/// #[derive(Copy, Clone, Contiguous)] +/// #[repr(u8)] +/// enum Test { +/// A = 0, +/// B = 1, +/// C = 2, +/// } +/// ``` +#[proc_macro_derive(Contiguous)] +pub fn derive_contiguous( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let expanded = + derive_marker_trait::(parse_macro_input!(input as DeriveInput)); + + proc_macro::TokenStream::from(expanded) +} + /// Basic wrapper for error handling fn derive_marker_trait(input: DeriveInput) -> TokenStream { derive_marker_trait_inner::(input).unwrap_or_else(|err| { @@ -131,10 +165,13 @@ fn derive_marker_trait_inner( Trait::check_attributes(&input.attrs)?; let asserts = Trait::struct_asserts(&input)?; let trait_params = Trait::generic_params(&input)?; + let trait_impl = Trait::trait_impl(&input)?; Ok(quote! { #asserts - unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {} + unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause { + #trait_impl + } }) } diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 3167ed2..91296b8 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -1,8 +1,9 @@ use proc_macro2::{Ident, TokenStream, TokenTree}; use quote::{quote, quote_spanned, ToTokens}; use syn::{ - spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput, - Fields, Type, + spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct, + DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp, + Variant, }; pub trait Derivable { @@ -10,10 +11,15 @@ pub trait Derivable { fn generic_params(_input: &DeriveInput) -> Result { Ok(quote!()) } - fn struct_asserts(input: &DeriveInput) -> Result; + fn struct_asserts(_input: &DeriveInput) -> Result { + Ok(quote!()) + } fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> { Ok(()) } + fn trait_impl(_input: &DeriveInput) -> Result { + Ok(quote!()) + } } pub struct Pod; @@ -91,7 +97,7 @@ impl Derivable for TransparentWrapper { let fields = get_struct_fields(input)?; Self::get_wrapper_type(&input.attrs, fields).map(|ty| quote!(<#ty>)) - .ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]") + .ok_or("when deriving TransparentWrapper for a struct with more than one field you need to specify the transparent field using #[transparent(T)]") } fn struct_asserts(input: &DeriveInput) -> Result { @@ -125,6 +131,55 @@ impl Derivable for TransparentWrapper { } } +pub struct Contiguous; + +impl Derivable for Contiguous { + fn ident() -> TokenStream { + quote!(::bytemuck::Contiguous) + } + + fn trait_impl(input: &DeriveInput) -> Result { + let repr = get_repr(&input.attrs) + .ok_or("Contiguous requires the enum to be #[repr(Int)]")?; + + if !repr.starts_with('u') && !repr.starts_with('i') { + return Err("Contiguous requires the enum to be #[repr(Int)]"); + } + + let variants = get_enum_variants(input)?; + let mut variants_with_discriminator = + VariantDiscriminantIterator::new(variants); + + let (min, max, count) = variants_with_discriminator.try_fold( + (i64::max_value(), i64::min_value(), 0), + |(min, max, count), res| { + let discriminator = res?; + Ok(( + i64::min(min, discriminator), + i64::max(max, discriminator), + count + 1, + )) + }, + )?; + + if max - min != count - 1 { + return Err( + "Contiguous requires the enum discriminants to be contiguous", + ); + } + + let repr_ident = Ident::new(&repr, input.span()); + let min_lit = LitInt::new(&format!("{}", min), input.span()); + let max_lit = LitInt::new(&format!("{}", max), input.span()); + + Ok(quote! { + type Int = #repr_ident; + const MIN_VALUE: #repr_ident = #min_lit; + const MAX_VALUE: #repr_ident = #max_lit; + }) + } +} + fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> { if let Data::Struct(DataStruct { fields, .. }) = &input.data { Ok(fields) @@ -133,6 +188,16 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> { } } +fn get_enum_variants<'a>( + input: &'a DeriveInput, +) -> Result + 'a, &'static str> { + if let Data::Enum(DataEnum { variants, .. }) = &input.data { + Ok(variants.iter()) + } else { + Err("deriving this trait is only supported for enums") + } +} + fn get_field_types<'a>( fields: &'a Fields, ) -> impl Iterator + 'a { @@ -205,3 +270,53 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option { fn get_repr(attributes: &[Attribute]) -> Option { get_simple_attr(attributes, "repr").map(|ident| ident.to_string()) } + +struct VariantDiscriminantIterator<'a, I: Iterator + 'a> { + inner: I, + last_value: i64, +} + +impl<'a, I: Iterator + 'a> + VariantDiscriminantIterator<'a, I> +{ + fn new(inner: I) -> Self { + VariantDiscriminantIterator { inner, last_value: -1 } + } +} + +impl<'a, I: Iterator + 'a> Iterator + for VariantDiscriminantIterator<'a, I> +{ + type Item = Result; + + fn next(&mut self) -> Option { + let variant = self.inner.next()?; + if !variant.fields.is_empty() { + return Some(Err("Only fieldless enums are supported")); + } + + if let Some((_, discriminant)) = &variant.discriminant { + let discriminant_value = match parse_int_expr(discriminant) { + Ok(value) => value, + Err(e) => return Some(Err(e)), + }; + self.last_value = discriminant_value; + } else { + self.last_value += 1; + } + + Some(Ok(self.last_value)) + } +} + +fn parse_int_expr(expr: &Expr) -> Result { + match expr { + Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => { + parse_int_expr(expr).map(|int| -int) + } + Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => { + int.base10_parse().map_err(|_| "Invalid integer expression") + } + _ => Err("Not an integer expression"), + } +} diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index 8ab69b2..867399a 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -use bytemuck_derive::{Pod, TransparentWrapper, Zeroable}; +use bytemuck_derive::{Contiguous, Pod, TransparentWrapper, Zeroable}; use std::marker::PhantomData; #[derive(Copy, Clone, Pod, Zeroable)] @@ -28,3 +28,23 @@ struct TransparentWithZeroSized { a: u16, b: PhantomData, } + +#[repr(u8)] +#[derive(Clone, Copy, Contiguous)] +enum ContiguousWithValues { + A = 0, + B = 1, + C = 2, + D = 3, + E = 4, +} + +#[repr(i8)] +#[derive(Clone, Copy, Contiguous)] +enum ContiguousWithImplicitValues { + A = -10, + B, + C, + D, + E, +} diff --git a/src/lib.rs b/src/lib.rs index b9a3ec2..c0fa320 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,7 +82,7 @@ mod transparent; pub use transparent::*; #[cfg(feature = "derive")] -pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper}; +pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper, Contiguous}; /*