mirror of
https://github.com/Lokathor/bytemuck.git
synced 2024-12-11 16:12:25 +00:00
Fix soundness issue of TransparentWrapper derive macro. (#173)
Uses the compiler to check that all non-wrapped fields are actually 1-ZSTs, and uses Zeroable to check that all non-wrapped fields are "conjurable". Additionally, relaxes the bound of `PhantomData<T: Zeroable>: Zeroable` to all `T: ?Sized`.
This commit is contained in:
parent
d1655f541b
commit
1039388f0b
1
.gitignore
vendored
1
.gitignore
vendored
@ -7,3 +7,4 @@ Cargo.lock
|
|||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
|
|
||||||
/derive/target/
|
/derive/target/
|
||||||
|
/derive/.vscode/
|
||||||
|
@ -218,8 +218,7 @@ impl Derivable for CheckedBitPattern {
|
|||||||
|
|
||||||
Ok(assert_fields_are_maybe_pod)
|
Ok(assert_fields_are_maybe_pod)
|
||||||
}
|
}
|
||||||
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
|
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed OK by NoUninit */
|
||||||
* OK by NoUninit */
|
|
||||||
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
|
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -273,21 +272,43 @@ impl Derivable for TransparentWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
|
fn asserts(input: &DeriveInput) -> Result<TokenStream> {
|
||||||
|
let (impl_generics, _ty_generics, where_clause) =
|
||||||
|
input.generics.split_for_impl();
|
||||||
let fields = get_struct_fields(input)?;
|
let fields = get_struct_fields(input)?;
|
||||||
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
|
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
|
||||||
Some(wrapped_type) => wrapped_type.to_string(),
|
Some(wrapped_type) => wrapped_type.to_string(),
|
||||||
None => unreachable!(), /* other code will already reject this derive */
|
None => unreachable!(), /* other code will already reject this derive */
|
||||||
};
|
};
|
||||||
let mut wrapped_fields = fields
|
let mut wrapped_field_ty = None;
|
||||||
.iter()
|
let mut nonwrapped_field_tys = vec![];
|
||||||
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
|
for field in fields.iter() {
|
||||||
if let None = wrapped_fields.next() {
|
let field_ty = &field.ty;
|
||||||
bail!("TransparentWrapper must have one field of the wrapped type");
|
if field_ty.to_token_stream().to_string() == wrapped_type {
|
||||||
};
|
if wrapped_field_ty.is_some() {
|
||||||
if let Some(_) = wrapped_fields.next() {
|
bail!(
|
||||||
bail!("TransparentWrapper can only have one field of the wrapped type")
|
"TransparentWrapper can only have one field of the wrapped type"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
wrapped_field_ty = Some(field_ty);
|
||||||
|
} else {
|
||||||
|
nonwrapped_field_tys.push(field_ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(wrapped_field_ty) = wrapped_field_ty {
|
||||||
|
Ok(quote!(
|
||||||
|
const _: () = {
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
|
||||||
|
fn assert_zeroable<Z: ::bytemuck::Zeroable>() {}
|
||||||
|
fn check #impl_generics () #where_clause {
|
||||||
|
#(
|
||||||
|
assert_zeroable::<#nonwrapped_field_tys>();
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
};
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok(quote!())
|
bail!("TransparentWrapper must have one field of the wrapped type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ use bytemuck::{
|
|||||||
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
|
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
|
||||||
TransparentWrapper, Zeroable,
|
TransparentWrapper, Zeroable,
|
||||||
};
|
};
|
||||||
use std::marker::PhantomData;
|
use std::marker::{PhantomData, PhantomPinned};
|
||||||
|
|
||||||
#[derive(Copy, Clone, Pod, Zeroable)]
|
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
@ -64,6 +64,14 @@ struct TransparentWithZeroSized<T> {
|
|||||||
b: PhantomData<T>,
|
b: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MyZst<T>(PhantomData<T>, [u8; 0], PhantomPinned);
|
||||||
|
unsafe impl<T> Zeroable for MyZst<T> {}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[transparent(u16)]
|
||||||
|
struct TransparentTupleWithCustomZeroSized<T>(u16, MyZst<T>);
|
||||||
|
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
#[derive(Clone, Copy, Contiguous)]
|
#[derive(Clone, Copy, Contiguous)]
|
||||||
enum ContiguousWithValues {
|
enum ContiguousWithValues {
|
||||||
@ -169,6 +177,21 @@ struct AnyBitPatternTest {
|
|||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
struct NewtypeWrapperTest<T>(T);
|
struct NewtypeWrapperTest<T>(T);
|
||||||
|
|
||||||
|
/// ```compile_fail
|
||||||
|
/// use bytemuck::TransparentWrapper;
|
||||||
|
///
|
||||||
|
/// struct NonTransparentSafeZST;
|
||||||
|
///
|
||||||
|
/// #[derive(TransparentWrapper)]
|
||||||
|
/// #[repr(transparent)]
|
||||||
|
/// struct Wrapper<T>(T, NonTransparentSafeZST);
|
||||||
|
/// ```
|
||||||
|
#[derive(
|
||||||
|
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
|
||||||
|
)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct TransarentWrapperZstTest<T>(T);
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fails_cast_contiguous() {
|
fn fails_cast_contiguous() {
|
||||||
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
|
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
|
||||||
@ -207,7 +230,14 @@ fn fails_cast_bytelit() {
|
|||||||
fn passes_cast_bytelit() {
|
fn passes_cast_bytelit() {
|
||||||
let res =
|
let res =
|
||||||
bytemuck::checked::cast_slice::<u8, CheckedBitPatternEnumByteLit>(b"CAB");
|
bytemuck::checked::cast_slice::<u8, CheckedBitPatternEnumByteLit>(b"CAB");
|
||||||
assert_eq!(res, [CheckedBitPatternEnumByteLit::C, CheckedBitPatternEnumByteLit::A, CheckedBitPatternEnumByteLit::B]);
|
assert_eq!(
|
||||||
|
res,
|
||||||
|
[
|
||||||
|
CheckedBitPatternEnumByteLit::C,
|
||||||
|
CheckedBitPatternEnumByteLit::A,
|
||||||
|
CheckedBitPatternEnumByteLit::B
|
||||||
|
]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -23,7 +23,9 @@ use super::*;
|
|||||||
/// the only non-ZST field.
|
/// the only non-ZST field.
|
||||||
///
|
///
|
||||||
/// 2. Any fields *other* than the `Inner` field must be trivially constructable
|
/// 2. Any fields *other* than the `Inner` field must be trivially constructable
|
||||||
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc.
|
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. (When deriving
|
||||||
|
/// `TransparentWrapper` on a type with ZST fields, the ZST fields must be
|
||||||
|
/// [`Zeroable`]).
|
||||||
///
|
///
|
||||||
/// 3. The `Wrapper` may not impose additional alignment requirements over
|
/// 3. The `Wrapper` may not impose additional alignment requirements over
|
||||||
/// `Inner`.
|
/// `Inner`.
|
||||||
@ -84,6 +86,43 @@ use super::*;
|
|||||||
/// let mut buf = [1, 2, 3u8];
|
/// let mut buf = [1, 2, 3u8];
|
||||||
/// let sm = Slice::wrap_mut(&mut buf);
|
/// let sm = Slice::wrap_mut(&mut buf);
|
||||||
/// ```
|
/// ```
|
||||||
|
///
|
||||||
|
/// ## Deriving
|
||||||
|
///
|
||||||
|
/// When deriving, the non-wrapped fields must uphold all the normal requirements,
|
||||||
|
/// and must also be `Zeroable`.
|
||||||
|
///
|
||||||
|
#[cfg_attr(feature = "derive", doc = "```")]
|
||||||
|
#[cfg_attr(
|
||||||
|
not(feature = "derive"),
|
||||||
|
doc = "```ignore
|
||||||
|
// This example requires the `derive` feature."
|
||||||
|
)]
|
||||||
|
/// use bytemuck::TransparentWrapper;
|
||||||
|
/// use std::marker::PhantomData;
|
||||||
|
///
|
||||||
|
/// #[derive(TransparentWrapper)]
|
||||||
|
/// #[repr(transparent)]
|
||||||
|
/// #[transparent(usize)]
|
||||||
|
/// struct Wrapper<T: ?Sized>(usize, PhantomData<T>); // PhantomData<T> implements Zeroable for all T
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Here, an error will occur, because `MyZst` does not implement `Zeroable`.
|
||||||
|
///
|
||||||
|
#[cfg_attr(feature = "derive", doc = "```compile_fail")]
|
||||||
|
#[cfg_attr(
|
||||||
|
not(feature = "derive"),
|
||||||
|
doc = "```ignore
|
||||||
|
// This example requires the `derive` feature."
|
||||||
|
)]
|
||||||
|
/// use bytemuck::TransparentWrapper;
|
||||||
|
/// struct MyZst;
|
||||||
|
///
|
||||||
|
/// #[derive(TransparentWrapper)]
|
||||||
|
/// #[repr(transparent)]
|
||||||
|
/// #[transparent(usize)]
|
||||||
|
/// struct Wrapper(usize, MyZst); // MyZst does not implement Zeroable
|
||||||
|
/// ```
|
||||||
pub unsafe trait TransparentWrapper<Inner: ?Sized> {
|
pub unsafe trait TransparentWrapper<Inner: ?Sized> {
|
||||||
/// Convert the inner type into the wrapper type.
|
/// Convert the inner type into the wrapper type.
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -64,7 +64,7 @@ unsafe impl<T> Zeroable for *const [T] {}
|
|||||||
unsafe impl Zeroable for *mut str {}
|
unsafe impl Zeroable for *mut str {}
|
||||||
unsafe impl Zeroable for *const str {}
|
unsafe impl Zeroable for *const str {}
|
||||||
|
|
||||||
unsafe impl<T: Zeroable> Zeroable for PhantomData<T> {}
|
unsafe impl<T: ?Sized> Zeroable for PhantomData<T> {}
|
||||||
unsafe impl Zeroable for PhantomPinned {}
|
unsafe impl Zeroable for PhantomPinned {}
|
||||||
unsafe impl<T: Zeroable> Zeroable for ManuallyDrop<T> {}
|
unsafe impl<T: Zeroable> Zeroable for ManuallyDrop<T> {}
|
||||||
unsafe impl<T: Zeroable> Zeroable for core::cell::UnsafeCell<T> {}
|
unsafe impl<T: Zeroable> Zeroable for core::cell::UnsafeCell<T> {}
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable};
|
use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
#[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)]
|
#[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
@ -26,7 +27,7 @@ struct TransparentWithZeroSized {
|
|||||||
|
|
||||||
#[derive(TransparentWrapper)]
|
#[derive(TransparentWrapper)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
struct TransparentWithGeneric<T> {
|
struct TransparentWithGeneric<T: ?Sized> {
|
||||||
a: T,
|
a: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,9 +40,9 @@ fn test_generic<T>(x: T) -> TransparentWithGeneric<T> {
|
|||||||
#[derive(TransparentWrapper)]
|
#[derive(TransparentWrapper)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
#[transparent(T)]
|
#[transparent(T)]
|
||||||
struct TransparentWithGenericAndZeroSized<T> {
|
struct TransparentWithGenericAndZeroSized<T: ?Sized> {
|
||||||
a: T,
|
a: (),
|
||||||
b: ()
|
b: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ensuring that no additional bounds are emitted.
|
/// Ensuring that no additional bounds are emitted.
|
||||||
@ -49,3 +50,28 @@ struct TransparentWithGenericAndZeroSized<T> {
|
|||||||
fn test_generic_with_zst<T>(x: T) -> TransparentWithGenericAndZeroSized<T> {
|
fn test_generic_with_zst<T>(x: T) -> TransparentWithGenericAndZeroSized<T> {
|
||||||
TransparentWithGenericAndZeroSized::wrap(x)
|
TransparentWithGenericAndZeroSized::wrap(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct TransparentUnsized {
|
||||||
|
a: dyn std::fmt::Debug,
|
||||||
|
}
|
||||||
|
|
||||||
|
type DynDebug = dyn std::fmt::Debug;
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[transparent(DynDebug)]
|
||||||
|
struct TransparentUnsizedWithZeroSized {
|
||||||
|
a: (),
|
||||||
|
b: DynDebug,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(TransparentWrapper)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[transparent(DynDebug)]
|
||||||
|
struct TransparentUnsizedWithGenericZeroSizeds<T: ?Sized, U: ?Sized> {
|
||||||
|
a: PhantomData<T>,
|
||||||
|
b: PhantomData<U>,
|
||||||
|
c: DynDebug,
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user