Fix soundness issue of TransparentWrapper derive macro. (#173)

Uses the compiler to check that all non-wrapped fields are actually 1-ZSTs,
and uses Zeroable to check that all non-wrapped fields are "conjurable".

Additionally, relaxes the bound of `PhantomData<T: Zeroable>: Zeroable` to all `T: ?Sized`.
This commit is contained in:
zachs18 2023-02-17 13:24:16 -06:00 committed by GitHub
parent d1655f541b
commit 1039388f0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 136 additions and 19 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ Cargo.lock
**/*.rs.bk **/*.rs.bk
/derive/target/ /derive/target/
/derive/.vscode/

View File

@ -218,8 +218,7 @@ impl Derivable for CheckedBitPattern {
Ok(assert_fields_are_maybe_pod) Ok(assert_fields_are_maybe_pod)
} }
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed OK by NoUninit */
* OK by NoUninit */
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
} }
} }
@ -273,21 +272,43 @@ impl Derivable for TransparentWrapper {
} }
fn asserts(input: &DeriveInput) -> Result<TokenStream> { fn asserts(input: &DeriveInput) -> Result<TokenStream> {
let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl();
let fields = get_struct_fields(input)?; let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) { let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(wrapped_type) => wrapped_type.to_string(), Some(wrapped_type) => wrapped_type.to_string(),
None => unreachable!(), /* other code will already reject this derive */ None => unreachable!(), /* other code will already reject this derive */
}; };
let mut wrapped_fields = fields let mut wrapped_field_ty = None;
.iter() let mut nonwrapped_field_tys = vec![];
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type); for field in fields.iter() {
if let None = wrapped_fields.next() { let field_ty = &field.ty;
bail!("TransparentWrapper must have one field of the wrapped type"); if field_ty.to_token_stream().to_string() == wrapped_type {
}; if wrapped_field_ty.is_some() {
if let Some(_) = wrapped_fields.next() { bail!(
bail!("TransparentWrapper can only have one field of the wrapped type") "TransparentWrapper can only have one field of the wrapped type"
);
}
wrapped_field_ty = Some(field_ty);
} else {
nonwrapped_field_tys.push(field_ty);
}
}
if let Some(wrapped_field_ty) = wrapped_field_ty {
Ok(quote!(
const _: () = {
#[repr(transparent)]
struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
fn assert_zeroable<Z: ::bytemuck::Zeroable>() {}
fn check #impl_generics () #where_clause {
#(
assert_zeroable::<#nonwrapped_field_tys>();
)*
}
};
))
} else { } else {
Ok(quote!()) bail!("TransparentWrapper must have one field of the wrapped type")
} }
} }

View File

@ -4,7 +4,7 @@ use bytemuck::{
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod, AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
TransparentWrapper, Zeroable, TransparentWrapper, Zeroable,
}; };
use std::marker::PhantomData; use std::marker::{PhantomData, PhantomPinned};
#[derive(Copy, Clone, Pod, Zeroable)] #[derive(Copy, Clone, Pod, Zeroable)]
#[repr(C)] #[repr(C)]
@ -64,6 +64,14 @@ struct TransparentWithZeroSized<T> {
b: PhantomData<T>, b: PhantomData<T>,
} }
struct MyZst<T>(PhantomData<T>, [u8; 0], PhantomPinned);
unsafe impl<T> Zeroable for MyZst<T> {}
#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(u16)]
struct TransparentTupleWithCustomZeroSized<T>(u16, MyZst<T>);
#[repr(u8)] #[repr(u8)]
#[derive(Clone, Copy, Contiguous)] #[derive(Clone, Copy, Contiguous)]
enum ContiguousWithValues { enum ContiguousWithValues {
@ -169,6 +177,21 @@ struct AnyBitPatternTest {
#[repr(transparent)] #[repr(transparent)]
struct NewtypeWrapperTest<T>(T); struct NewtypeWrapperTest<T>(T);
/// ```compile_fail
/// use bytemuck::TransparentWrapper;
///
/// struct NonTransparentSafeZST;
///
/// #[derive(TransparentWrapper)]
/// #[repr(transparent)]
/// struct Wrapper<T>(T, NonTransparentSafeZST);
/// ```
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
)]
#[repr(transparent)]
struct TransarentWrapperZstTest<T>(T);
#[test] #[test]
fn fails_cast_contiguous() { fn fails_cast_contiguous() {
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5); let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
@ -207,7 +230,14 @@ fn fails_cast_bytelit() {
fn passes_cast_bytelit() { fn passes_cast_bytelit() {
let res = let res =
bytemuck::checked::cast_slice::<u8, CheckedBitPatternEnumByteLit>(b"CAB"); bytemuck::checked::cast_slice::<u8, CheckedBitPatternEnumByteLit>(b"CAB");
assert_eq!(res, [CheckedBitPatternEnumByteLit::C, CheckedBitPatternEnumByteLit::A, CheckedBitPatternEnumByteLit::B]); assert_eq!(
res,
[
CheckedBitPatternEnumByteLit::C,
CheckedBitPatternEnumByteLit::A,
CheckedBitPatternEnumByteLit::B
]
);
} }
#[test] #[test]

View File

@ -23,7 +23,9 @@ use super::*;
/// the only non-ZST field. /// the only non-ZST field.
/// ///
/// 2. Any fields *other* than the `Inner` field must be trivially constructable /// 2. Any fields *other* than the `Inner` field must be trivially constructable
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. /// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. (When deriving
/// `TransparentWrapper` on a type with ZST fields, the ZST fields must be
/// [`Zeroable`]).
/// ///
/// 3. The `Wrapper` may not impose additional alignment requirements over /// 3. The `Wrapper` may not impose additional alignment requirements over
/// `Inner`. /// `Inner`.
@ -84,6 +86,43 @@ use super::*;
/// let mut buf = [1, 2, 3u8]; /// let mut buf = [1, 2, 3u8];
/// let sm = Slice::wrap_mut(&mut buf); /// let sm = Slice::wrap_mut(&mut buf);
/// ``` /// ```
///
/// ## Deriving
///
/// When deriving, the non-wrapped fields must uphold all the normal requirements,
/// and must also be `Zeroable`.
///
#[cfg_attr(feature = "derive", doc = "```")]
#[cfg_attr(
not(feature = "derive"),
doc = "```ignore
// This example requires the `derive` feature."
)]
/// use bytemuck::TransparentWrapper;
/// use std::marker::PhantomData;
///
/// #[derive(TransparentWrapper)]
/// #[repr(transparent)]
/// #[transparent(usize)]
/// struct Wrapper<T: ?Sized>(usize, PhantomData<T>); // PhantomData<T> implements Zeroable for all T
/// ```
///
/// Here, an error will occur, because `MyZst` does not implement `Zeroable`.
///
#[cfg_attr(feature = "derive", doc = "```compile_fail")]
#[cfg_attr(
not(feature = "derive"),
doc = "```ignore
// This example requires the `derive` feature."
)]
/// use bytemuck::TransparentWrapper;
/// struct MyZst;
///
/// #[derive(TransparentWrapper)]
/// #[repr(transparent)]
/// #[transparent(usize)]
/// struct Wrapper(usize, MyZst); // MyZst does not implement Zeroable
/// ```
pub unsafe trait TransparentWrapper<Inner: ?Sized> { pub unsafe trait TransparentWrapper<Inner: ?Sized> {
/// Convert the inner type into the wrapper type. /// Convert the inner type into the wrapper type.
#[inline] #[inline]

View File

@ -64,7 +64,7 @@ unsafe impl<T> Zeroable for *const [T] {}
unsafe impl Zeroable for *mut str {} unsafe impl Zeroable for *mut str {}
unsafe impl Zeroable for *const str {} unsafe impl Zeroable for *const str {}
unsafe impl<T: Zeroable> Zeroable for PhantomData<T> {} unsafe impl<T: ?Sized> Zeroable for PhantomData<T> {}
unsafe impl Zeroable for PhantomPinned {} unsafe impl Zeroable for PhantomPinned {}
unsafe impl<T: Zeroable> Zeroable for ManuallyDrop<T> {} unsafe impl<T: Zeroable> Zeroable for ManuallyDrop<T> {}
unsafe impl<T: Zeroable> Zeroable for core::cell::UnsafeCell<T> {} unsafe impl<T: Zeroable> Zeroable for core::cell::UnsafeCell<T> {}

View File

@ -2,6 +2,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable}; use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable};
use std::marker::PhantomData;
#[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)] #[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)]
#[repr(C)] #[repr(C)]
@ -26,7 +27,7 @@ struct TransparentWithZeroSized {
#[derive(TransparentWrapper)] #[derive(TransparentWrapper)]
#[repr(transparent)] #[repr(transparent)]
struct TransparentWithGeneric<T> { struct TransparentWithGeneric<T: ?Sized> {
a: T, a: T,
} }
@ -39,9 +40,9 @@ fn test_generic<T>(x: T) -> TransparentWithGeneric<T> {
#[derive(TransparentWrapper)] #[derive(TransparentWrapper)]
#[repr(transparent)] #[repr(transparent)]
#[transparent(T)] #[transparent(T)]
struct TransparentWithGenericAndZeroSized<T> { struct TransparentWithGenericAndZeroSized<T: ?Sized> {
a: T, a: (),
b: () b: T,
} }
/// Ensuring that no additional bounds are emitted. /// Ensuring that no additional bounds are emitted.
@ -49,3 +50,28 @@ struct TransparentWithGenericAndZeroSized<T> {
fn test_generic_with_zst<T>(x: T) -> TransparentWithGenericAndZeroSized<T> { fn test_generic_with_zst<T>(x: T) -> TransparentWithGenericAndZeroSized<T> {
TransparentWithGenericAndZeroSized::wrap(x) TransparentWithGenericAndZeroSized::wrap(x)
} }
#[derive(TransparentWrapper)]
#[repr(transparent)]
struct TransparentUnsized {
a: dyn std::fmt::Debug,
}
type DynDebug = dyn std::fmt::Debug;
#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(DynDebug)]
struct TransparentUnsizedWithZeroSized {
a: (),
b: DynDebug,
}
#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(DynDebug)]
struct TransparentUnsizedWithGenericZeroSizeds<T: ?Sized, U: ?Sized> {
a: PhantomData<T>,
b: PhantomData<U>,
c: DynDebug,
}