mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-12-03 12:12:26 +00:00
Fix soundness issue of TransparentWrapper derive macro. (#173)
Uses the compiler to check that all non-wrapped fields are actually 1-ZSTs, and uses Zeroable to check that all non-wrapped fields are "conjurable". Additionally, relaxes the bound of `PhantomData<T: Zeroable>: Zeroable` to all `T: ?Sized`.
This commit is contained in:
parent
d1655f541b
commit
1039388f0b
1
.gitignore
vendored
1
.gitignore
vendored
@ -7,3 +7,4 @@ Cargo.lock
|
||||
**/*.rs.bk
|
||||
|
||||
/derive/target/
|
||||
/derive/.vscode/
|
||||
|
@ -218,8 +218,7 @@ impl Derivable for CheckedBitPattern {
|
||||
|
||||
Ok(assert_fields_are_maybe_pod)
|
||||
}
|
||||
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
|
||||
* OK by NoUninit */
|
||||
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed OK by NoUninit */
|
||||
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
|
||||
}
|
||||
}
|
||||
@ -273,21 +272,43 @@ impl Derivable for TransparentWrapper {
|
||||
}
|
||||
|
||||
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
|
||||
let (impl_generics, _ty_generics, where_clause) =
|
||||
input.generics.split_for_impl();
|
||||
let fields = get_struct_fields(input)?;
|
||||
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
|
||||
Some(wrapped_type) => wrapped_type.to_string(),
|
||||
None => unreachable!(), /* other code will already reject this derive */
|
||||
};
|
||||
let mut wrapped_fields = fields
|
||||
.iter()
|
||||
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
|
||||
if let None = wrapped_fields.next() {
|
||||
bail!("TransparentWrapper must have one field of the wrapped type");
|
||||
};
|
||||
if let Some(_) = wrapped_fields.next() {
|
||||
bail!("TransparentWrapper can only have one field of the wrapped type")
|
||||
let mut wrapped_field_ty = None;
|
||||
let mut nonwrapped_field_tys = vec![];
|
||||
for field in fields.iter() {
|
||||
let field_ty = &field.ty;
|
||||
if field_ty.to_token_stream().to_string() == wrapped_type {
|
||||
if wrapped_field_ty.is_some() {
|
||||
bail!(
|
||||
"TransparentWrapper can only have one field of the wrapped type"
|
||||
);
|
||||
}
|
||||
wrapped_field_ty = Some(field_ty);
|
||||
} else {
|
||||
nonwrapped_field_tys.push(field_ty);
|
||||
}
|
||||
}
|
||||
if let Some(wrapped_field_ty) = wrapped_field_ty {
|
||||
Ok(quote!(
|
||||
const _: () = {
|
||||
#[repr(transparent)]
|
||||
struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
|
||||
fn assert_zeroable<Z: ::bytemuck::Zeroable>() {}
|
||||
fn check #impl_generics () #where_clause {
|
||||
#(
|
||||
assert_zeroable::<#nonwrapped_field_tys>();
|
||||
)*
|
||||
}
|
||||
};
|
||||
))
|
||||
} else {
|
||||
Ok(quote!())
|
||||
bail!("TransparentWrapper must have one field of the wrapped type")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,7 @@ use bytemuck::{
|
||||
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
|
||||
TransparentWrapper, Zeroable,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
use std::marker::{PhantomData, PhantomPinned};
|
||||
|
||||
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||
#[repr(C)]
|
||||
@ -64,6 +64,14 @@ struct TransparentWithZeroSized<T> {
|
||||
b: PhantomData<T>,
|
||||
}
|
||||
|
||||
struct MyZst<T>(PhantomData<T>, [u8; 0], PhantomPinned);
|
||||
unsafe impl<T> Zeroable for MyZst<T> {}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
#[transparent(u16)]
|
||||
struct TransparentTupleWithCustomZeroSized<T>(u16, MyZst<T>);
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, Contiguous)]
|
||||
enum ContiguousWithValues {
|
||||
@ -169,6 +177,21 @@ struct AnyBitPatternTest {
|
||||
#[repr(transparent)]
|
||||
struct NewtypeWrapperTest<T>(T);
|
||||
|
||||
/// ```compile_fail
|
||||
/// use bytemuck::TransparentWrapper;
|
||||
///
|
||||
/// struct NonTransparentSafeZST;
|
||||
///
|
||||
/// #[derive(TransparentWrapper)]
|
||||
/// #[repr(transparent)]
|
||||
/// struct Wrapper<T>(T, NonTransparentSafeZST);
|
||||
/// ```
|
||||
#[derive(
|
||||
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
struct TransarentWrapperZstTest<T>(T);
|
||||
|
||||
#[test]
|
||||
fn fails_cast_contiguous() {
|
||||
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
|
||||
@ -207,7 +230,14 @@ fn fails_cast_bytelit() {
|
||||
fn passes_cast_bytelit() {
|
||||
let res =
|
||||
bytemuck::checked::cast_slice::<u8, CheckedBitPatternEnumByteLit>(b"CAB");
|
||||
assert_eq!(res, [CheckedBitPatternEnumByteLit::C, CheckedBitPatternEnumByteLit::A, CheckedBitPatternEnumByteLit::B]);
|
||||
assert_eq!(
|
||||
res,
|
||||
[
|
||||
CheckedBitPatternEnumByteLit::C,
|
||||
CheckedBitPatternEnumByteLit::A,
|
||||
CheckedBitPatternEnumByteLit::B
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -23,7 +23,9 @@ use super::*;
|
||||
/// the only non-ZST field.
|
||||
///
|
||||
/// 2. Any fields *other* than the `Inner` field must be trivially constructable
|
||||
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc.
|
||||
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. (When deriving
|
||||
/// `TransparentWrapper` on a type with ZST fields, the ZST fields must be
|
||||
/// [`Zeroable`]).
|
||||
///
|
||||
/// 3. The `Wrapper` may not impose additional alignment requirements over
|
||||
/// `Inner`.
|
||||
@ -84,6 +86,43 @@ use super::*;
|
||||
/// let mut buf = [1, 2, 3u8];
|
||||
/// let sm = Slice::wrap_mut(&mut buf);
|
||||
/// ```
|
||||
///
|
||||
/// ## Deriving
|
||||
///
|
||||
/// When deriving, the non-wrapped fields must uphold all the normal requirements,
|
||||
/// and must also be `Zeroable`.
|
||||
///
|
||||
#[cfg_attr(feature = "derive", doc = "```")]
|
||||
#[cfg_attr(
|
||||
not(feature = "derive"),
|
||||
doc = "```ignore
|
||||
// This example requires the `derive` feature."
|
||||
)]
|
||||
/// use bytemuck::TransparentWrapper;
|
||||
/// use std::marker::PhantomData;
|
||||
///
|
||||
/// #[derive(TransparentWrapper)]
|
||||
/// #[repr(transparent)]
|
||||
/// #[transparent(usize)]
|
||||
/// struct Wrapper<T: ?Sized>(usize, PhantomData<T>); // PhantomData<T> implements Zeroable for all T
|
||||
/// ```
|
||||
///
|
||||
/// Here, an error will occur, because `MyZst` does not implement `Zeroable`.
|
||||
///
|
||||
#[cfg_attr(feature = "derive", doc = "```compile_fail")]
|
||||
#[cfg_attr(
|
||||
not(feature = "derive"),
|
||||
doc = "```ignore
|
||||
// This example requires the `derive` feature."
|
||||
)]
|
||||
/// use bytemuck::TransparentWrapper;
|
||||
/// struct MyZst;
|
||||
///
|
||||
/// #[derive(TransparentWrapper)]
|
||||
/// #[repr(transparent)]
|
||||
/// #[transparent(usize)]
|
||||
/// struct Wrapper(usize, MyZst); // MyZst does not implement Zeroable
|
||||
/// ```
|
||||
pub unsafe trait TransparentWrapper<Inner: ?Sized> {
|
||||
/// Convert the inner type into the wrapper type.
|
||||
#[inline]
|
||||
|
@ -64,7 +64,7 @@ unsafe impl<T> Zeroable for *const [T] {}
|
||||
unsafe impl Zeroable for *mut str {}
|
||||
unsafe impl Zeroable for *const str {}
|
||||
|
||||
unsafe impl<T: Zeroable> Zeroable for PhantomData<T> {}
|
||||
unsafe impl<T: ?Sized> Zeroable for PhantomData<T> {}
|
||||
unsafe impl Zeroable for PhantomPinned {}
|
||||
unsafe impl<T: Zeroable> Zeroable for ManuallyDrop<T> {}
|
||||
unsafe impl<T: Zeroable> Zeroable for core::cell::UnsafeCell<T> {}
|
||||
|
@ -2,6 +2,7 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)]
|
||||
#[repr(C)]
|
||||
@ -26,7 +27,7 @@ struct TransparentWithZeroSized {
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
struct TransparentWithGeneric<T> {
|
||||
struct TransparentWithGeneric<T: ?Sized> {
|
||||
a: T,
|
||||
}
|
||||
|
||||
@ -39,9 +40,9 @@ fn test_generic<T>(x: T) -> TransparentWithGeneric<T> {
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
#[transparent(T)]
|
||||
struct TransparentWithGenericAndZeroSized<T> {
|
||||
a: T,
|
||||
b: ()
|
||||
struct TransparentWithGenericAndZeroSized<T: ?Sized> {
|
||||
a: (),
|
||||
b: T,
|
||||
}
|
||||
|
||||
/// Ensuring that no additional bounds are emitted.
|
||||
@ -49,3 +50,28 @@ struct TransparentWithGenericAndZeroSized<T> {
|
||||
fn test_generic_with_zst<T>(x: T) -> TransparentWithGenericAndZeroSized<T> {
|
||||
TransparentWithGenericAndZeroSized::wrap(x)
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
struct TransparentUnsized {
|
||||
a: dyn std::fmt::Debug,
|
||||
}
|
||||
|
||||
type DynDebug = dyn std::fmt::Debug;
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
#[transparent(DynDebug)]
|
||||
struct TransparentUnsizedWithZeroSized {
|
||||
a: (),
|
||||
b: DynDebug,
|
||||
}
|
||||
|
||||
#[derive(TransparentWrapper)]
|
||||
#[repr(transparent)]
|
||||
#[transparent(DynDebug)]
|
||||
struct TransparentUnsizedWithGenericZeroSizeds<T: ?Sized, U: ?Sized> {
|
||||
a: PhantomData<T>,
|
||||
b: PhantomData<U>,
|
||||
c: DynDebug,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user