add derive macro for Contiguous (#31)

This commit is contained in:
Robin Appelman 2020-08-21 22:08:34 +00:00 committed by GitHub
parent e202fa2756
commit 24b71e078f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 180 additions and 8 deletions

View File

@ -6,7 +6,7 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
use crate::traits::{Derivable, Pod, TransparentWrapper, Zeroable};
use crate::traits::{Contiguous, Derivable, Pod, TransparentWrapper, Zeroable};
/// Derive the `Pod` trait for a struct
///
@ -110,6 +110,40 @@ pub fn derive_transparent(
proc_macro::TokenStream::from(expanded)
}
/// Derive the `Contiguous` trait for an enum
///
/// The macro ensures that the enum follows all the the safety requirements
/// for the `Contiguous` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
///
/// - The enum must be `#[repr(Int)]`
/// - The enum must be fieldless
/// - The enum discriminants must form a contiguous range
///
/// ## Example
///
/// ```rust
/// # use bytemuck_derive::{Contiguous};
///
/// #[derive(Copy, Clone, Contiguous)]
/// #[repr(u8)]
/// enum Test {
/// A = 0,
/// B = 1,
/// C = 2,
/// }
/// ```
#[proc_macro_derive(Contiguous)]
pub fn derive_contiguous(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
proc_macro::TokenStream::from(expanded)
}
/// Basic wrapper for error handling
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
derive_marker_trait_inner::<Trait>(input).unwrap_or_else(|err| {
@ -131,10 +165,13 @@ fn derive_marker_trait_inner<Trait: Derivable>(
Trait::check_attributes(&input.attrs)?;
let asserts = Trait::struct_asserts(&input)?;
let trait_params = Trait::generic_params(&input)?;
let trait_impl = Trait::trait_impl(&input)?;
Ok(quote! {
#asserts
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {}
unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {
#trait_impl
}
})
}

View File

@ -1,8 +1,9 @@
use proc_macro2::{Ident, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
spanned::Spanned, AttrStyle, Attribute, Data, DataStruct, DeriveInput,
Fields, Type,
spanned::Spanned, AttrStyle, Attribute, Data, DataEnum, DataStruct,
DeriveInput, Expr, ExprLit, ExprUnary, Fields, Lit, LitInt, Type, UnOp,
Variant,
};
pub trait Derivable {
@ -10,10 +11,15 @@ pub trait Derivable {
fn generic_params(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
Ok(quote!())
}
fn struct_asserts(input: &DeriveInput) -> Result<TokenStream, &'static str>;
fn struct_asserts(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
Ok(quote!())
}
fn check_attributes(_attributes: &[Attribute]) -> Result<(), &'static str> {
Ok(())
}
fn trait_impl(_input: &DeriveInput) -> Result<TokenStream, &'static str> {
Ok(quote!())
}
}
pub struct Pod;
@ -125,6 +131,55 @@ impl Derivable for TransparentWrapper {
}
}
pub struct Contiguous;
impl Derivable for Contiguous {
fn ident() -> TokenStream {
quote!(::bytemuck::Contiguous)
}
fn trait_impl(input: &DeriveInput) -> Result<TokenStream, &'static str> {
let repr = get_repr(&input.attrs)
.ok_or("Contiguous requires the enum to be #[repr(Int)]")?;
if !repr.starts_with('u') && !repr.starts_with('i') {
return Err("Contiguous requires the enum to be #[repr(Int)]");
}
let variants = get_enum_variants(input)?;
let mut variants_with_discriminator =
VariantDiscriminantIterator::new(variants);
let (min, max, count) = variants_with_discriminator.try_fold(
(i64::max_value(), i64::min_value(), 0),
|(min, max, count), res| {
let discriminator = res?;
Ok((
i64::min(min, discriminator),
i64::max(max, discriminator),
count + 1,
))
},
)?;
if max - min != count - 1 {
return Err(
"Contiguous requires the enum discriminants to be contiguous",
);
}
let repr_ident = Ident::new(&repr, input.span());
let min_lit = LitInt::new(&format!("{}", min), input.span());
let max_lit = LitInt::new(&format!("{}", max), input.span());
Ok(quote! {
type Int = #repr_ident;
const MIN_VALUE: #repr_ident = #min_lit;
const MAX_VALUE: #repr_ident = #max_lit;
})
}
}
fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
if let Data::Struct(DataStruct { fields, .. }) = &input.data {
Ok(fields)
@ -133,6 +188,16 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields, &'static str> {
}
}
fn get_enum_variants<'a>(
input: &'a DeriveInput,
) -> Result<impl Iterator<Item = &'a Variant> + 'a, &'static str> {
if let Data::Enum(DataEnum { variants, .. }) = &input.data {
Ok(variants.iter())
} else {
Err("deriving this trait is only supported for enums")
}
}
fn get_field_types<'a>(
fields: &'a Fields,
) -> impl Iterator<Item = &'a Type> + 'a {
@ -205,3 +270,53 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
fn get_repr(attributes: &[Attribute]) -> Option<String> {
get_simple_attr(attributes, "repr").map(|ident| ident.to_string())
}
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
inner: I,
last_value: i64,
}
impl<'a, I: Iterator<Item = &'a Variant> + 'a>
VariantDiscriminantIterator<'a, I>
{
fn new(inner: I) -> Self {
VariantDiscriminantIterator { inner, last_value: -1 }
}
}
impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
for VariantDiscriminantIterator<'a, I>
{
type Item = Result<i64, &'static str>;
fn next(&mut self) -> Option<Self::Item> {
let variant = self.inner.next()?;
if !variant.fields.is_empty() {
return Some(Err("Only fieldless enums are supported"));
}
if let Some((_, discriminant)) = &variant.discriminant {
let discriminant_value = match parse_int_expr(discriminant) {
Ok(value) => value,
Err(e) => return Some(Err(e)),
};
self.last_value = discriminant_value;
} else {
self.last_value += 1;
}
Some(Ok(self.last_value))
}
}
fn parse_int_expr(expr: &Expr) -> Result<i64, &'static str> {
match expr {
Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
parse_int_expr(expr).map(|int| -int)
}
Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
int.base10_parse().map_err(|_| "Invalid integer expression")
}
_ => Err("Not an integer expression"),
}
}

View File

@ -1,6 +1,6 @@
#![allow(dead_code)]
use bytemuck_derive::{Pod, TransparentWrapper, Zeroable};
use bytemuck_derive::{Contiguous, Pod, TransparentWrapper, Zeroable};
use std::marker::PhantomData;
#[derive(Copy, Clone, Pod, Zeroable)]
@ -28,3 +28,23 @@ struct TransparentWithZeroSized<T> {
a: u16,
b: PhantomData<T>,
}
#[repr(u8)]
#[derive(Clone, Copy, Contiguous)]
enum ContiguousWithValues {
A = 0,
B = 1,
C = 2,
D = 3,
E = 4,
}
#[repr(i8)]
#[derive(Clone, Copy, Contiguous)]
enum ContiguousWithImplicitValues {
A = -10,
B,
C,
D,
E,
}

View File

@ -82,7 +82,7 @@ mod transparent;
pub use transparent::*;
#[cfg(feature = "derive")]
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper};
pub use bytemuck_derive::{Zeroable, Pod, TransparentWrapper, Contiguous};
/*