From 09dd2ffd682182294a09d5874b0e255d699d833f Mon Sep 17 00:00:00 2001 From: zachs18 <8355914+zachs18@users.noreply.github.com> Date: Thu, 1 Sep 2022 18:23:28 -0500 Subject: [PATCH] Add `cast_{arc,rc}` (and slice and try), and `{wrap,peel}_{arc,rc}`. (#132) * Add `allocation::{try_,}cast_{arc,rc}`, and add `{wrap,peel}_{arc,rc}` to `TransparentWrapperAlloc`. * Avoid intermediate slice reference in `try_cast_slice_{arc,rc}`. * remove `unsafe` block; run `cargo +nightly fmt` (ignoring files I didn't modify) * Make `cast_rc` (etc) have the same bounds as `cast_mut`, due to the existence of `Rc::get_mut_unchecked`. --- src/allocation.rs | 289 +++++++++++++++++++++++++++++++++++++++++++ tests/transparent.rs | 15 +++ 2 files changed, 304 insertions(+) diff --git a/src/allocation.rs b/src/allocation.rs index 93664bf..8c29864 100644 --- a/src/allocation.rs +++ b/src/allocation.rs @@ -13,6 +13,8 @@ use super::*; use alloc::{ alloc::{alloc_zeroed, Layout}, boxed::Box, + rc::Rc, + sync::Arc, vec, vec::Vec, }; @@ -315,6 +317,205 @@ pub fn pod_collect_to_vec< dst } +/// As [`try_cast_rc`](try_cast_rc), but unwraps for you. +#[inline] +pub fn cast_rc( + input: Rc, +) -> Rc { + try_cast_rc(input).map_err(|(e, _v)| e).unwrap() +} + +/// Attempts to cast the content type of a [`Rc`](alloc::rc::Rc). +/// +/// On failure you get back an error along with the starting `Rc`. +/// +/// The bounds on this function are the same as [`cast_mut`], because a user +/// could call `Rc::get_unchecked_mut` on the output, which could be observable +/// in the input. +/// +/// ## Failure +/// +/// * The start and end content type of the `Rc` must have the exact same +/// alignment. +/// * The start and end size of the `Rc` must have the exact same size. +#[inline] +pub fn try_cast_rc( + input: Rc, +) -> Result, (PodCastError, Rc)> { + if align_of::() != align_of::() { + Err((PodCastError::AlignmentMismatch, input)) + } else if size_of::() != size_of::() { + Err((PodCastError::SizeMismatch, input)) + } else { + // Safety: Rc::from_raw requires size and alignment match, which is met. + let ptr: *const B = Rc::into_raw(input) as *const B; + Ok(unsafe { Rc::from_raw(ptr) }) + } +} + +/// As [`try_cast_arc`](try_cast_arc), but unwraps for you. +#[inline] +pub fn cast_arc( + input: Arc, +) -> Arc { + try_cast_arc(input).map_err(|(e, _v)| e).unwrap() +} + +/// Attempts to cast the content type of a [`Arc`](alloc::sync::Arc). +/// +/// On failure you get back an error along with the starting `Arc`. +/// +/// The bounds on this function are the same as [`cast_mut`], because a user +/// could call `Rc::get_unchecked_mut` on the output, which could be observable +/// in the input. +/// +/// ## Failure +/// +/// * The start and end content type of the `Arc` must have the exact same +/// alignment. +/// * The start and end size of the `Arc` must have the exact same size. +#[inline] +pub fn try_cast_arc< + A: NoUninit + AnyBitPattern, + B: NoUninit + AnyBitPattern, +>( + input: Arc, +) -> Result, (PodCastError, Arc)> { + if align_of::() != align_of::() { + Err((PodCastError::AlignmentMismatch, input)) + } else if size_of::() != size_of::() { + Err((PodCastError::SizeMismatch, input)) + } else { + // Safety: Arc::from_raw requires size and alignment match, which is met. + let ptr: *const B = Arc::into_raw(input) as *const B; + Ok(unsafe { Arc::from_raw(ptr) }) + } +} + +/// As [`try_cast_slice_rc`](try_cast_slice_rc), but unwraps for you. +#[inline] +pub fn cast_slice_rc< + A: NoUninit + AnyBitPattern, + B: NoUninit + AnyBitPattern, +>( + input: Rc<[A]>, +) -> Rc<[B]> { + try_cast_slice_rc(input).map_err(|(e, _v)| e).unwrap() +} + +/// Attempts to cast the content type of a `Rc<[T]>`. +/// +/// On failure you get back an error along with the starting `Rc<[T]>`. +/// +/// The bounds on this function are the same as [`cast_mut`], because a user +/// could call `Rc::get_unchecked_mut` on the output, which could be observable +/// in the input. +/// +/// ## Failure +/// +/// * The start and end content type of the `Rc<[T]>` must have the exact same +/// alignment. +/// * The start and end content size in bytes of the `Rc<[T]>` must be the exact +/// same. +#[inline] +pub fn try_cast_slice_rc< + A: NoUninit + AnyBitPattern, + B: NoUninit + AnyBitPattern, +>( + input: Rc<[A]>, +) -> Result, (PodCastError, Rc<[A]>)> { + if align_of::() != align_of::() { + Err((PodCastError::AlignmentMismatch, input)) + } else if size_of::() != size_of::() { + if size_of::() * input.len() % size_of::() != 0 { + // If the size in bytes of the underlying buffer does not match an exact + // multiple of the size of B, we cannot cast between them. + Err((PodCastError::SizeMismatch, input)) + } else { + // Because the size is an exact multiple, we can now change the length + // of the slice and recreate the Rc + // NOTE: This is a valid operation because according to the docs of + // std::rc::Rc::from_raw(), the type U that was in the original Rc + // acquired from Rc::into_raw() must have the same size alignment and + // size of the type T in the new Rc. So as long as both the size + // and alignment stay the same, the Rc will remain a valid Rc. + let length = size_of::() * input.len() / size_of::(); + let rc_ptr: *const A = Rc::into_raw(input) as *const A; + // Must use ptr::slice_from_raw_parts, because we cannot make an + // intermediate const reference, because it has mutable provenance, + // nor an intermediate mutable reference, because it could be aliased. + let ptr = core::ptr::slice_from_raw_parts(rc_ptr as *const B, length); + Ok(unsafe { Rc::<[B]>::from_raw(ptr) }) + } + } else { + let rc_ptr: *const [A] = Rc::into_raw(input); + let ptr: *const [B] = rc_ptr as *const [B]; + Ok(unsafe { Rc::<[B]>::from_raw(ptr) }) + } +} + +/// As [`try_cast_slice_arc`](try_cast_slice_arc), but unwraps for you. +#[inline] +pub fn cast_slice_arc< + A: NoUninit + AnyBitPattern, + B: NoUninit + AnyBitPattern, +>( + input: Arc<[A]>, +) -> Arc<[B]> { + try_cast_slice_arc(input).map_err(|(e, _v)| e).unwrap() +} + +/// Attempts to cast the content type of a `Arc<[T]>`. +/// +/// On failure you get back an error along with the starting `Arc<[T]>`. +/// +/// The bounds on this function are the same as [`cast_mut`], because a user +/// could call `Rc::get_unchecked_mut` on the output, which could be observable +/// in the input. +/// +/// ## Failure +/// +/// * The start and end content type of the `Arc<[T]>` must have the exact same +/// alignment. +/// * The start and end content size in bytes of the `Arc<[T]>` must be the +/// exact same. +#[inline] +pub fn try_cast_slice_arc< + A: NoUninit + AnyBitPattern, + B: NoUninit + AnyBitPattern, +>( + input: Arc<[A]>, +) -> Result, (PodCastError, Arc<[A]>)> { + if align_of::() != align_of::() { + Err((PodCastError::AlignmentMismatch, input)) + } else if size_of::() != size_of::() { + if size_of::() * input.len() % size_of::() != 0 { + // If the size in bytes of the underlying buffer does not match an exact + // multiple of the size of B, we cannot cast between them. + Err((PodCastError::SizeMismatch, input)) + } else { + // Because the size is an exact multiple, we can now change the length + // of the slice and recreate the Arc + // NOTE: This is a valid operation because according to the docs of + // std::sync::Arc::from_raw(), the type U that was in the original Arc + // acquired from Arc::into_raw() must have the same size alignment and + // size of the type T in the new Arc. So as long as both the size + // and alignment stay the same, the Arc will remain a valid Arc. + let length = size_of::() * input.len() / size_of::(); + let arc_ptr: *const A = Arc::into_raw(input) as *const A; + // Must use ptr::slice_from_raw_parts, because we cannot make an + // intermediate const reference, because it has mutable provenance, + // nor an intermediate mutable reference, because it could be aliased. + let ptr = core::ptr::slice_from_raw_parts(arc_ptr as *const B, length); + Ok(unsafe { Arc::<[B]>::from_raw(ptr) }) + } + } else { + let arc_ptr: *const [A] = Arc::into_raw(input); + let ptr: *const [B] = arc_ptr as *const [B]; + Ok(unsafe { Arc::<[B]>::from_raw(ptr) }) + } +} + /// An extension trait for `TransparentWrapper` and alloc types. pub trait TransparentWrapperAlloc: TransparentWrapper @@ -364,6 +565,50 @@ pub trait TransparentWrapperAlloc: } } + /// Convert an [`Rc`](alloc::rc::Rc) to the inner type into an `Rc` to the + /// wrapper type. + #[inline] + fn wrap_rc(s: Rc) -> Rc { + assert!(size_of::<*mut Inner>() == size_of::<*mut Self>()); + + unsafe { + // A pointer cast doesn't work here because rustc can't tell that + // the vtables match (because of the `?Sized` restriction relaxation). + // A `transmute` doesn't work because the layout of Rc is unspecified. + // + // SAFETY: + // * The unsafe contract requires that pointers to Inner and Self have + // identical representations, and that the size and alignment of Inner + // and Self are the same, which meets the safety requirements of + // Rc::from_raw + let inner_ptr: *const Inner = Rc::into_raw(s); + let wrapper_ptr: *const Self = transmute!(inner_ptr); + Rc::from_raw(wrapper_ptr) + } + } + + /// Convert an [`Arc`](alloc::sync::Arc) to the inner type into an `Arc` to + /// the wrapper type. + #[inline] + fn wrap_arc(s: Arc) -> Arc { + assert!(size_of::<*mut Inner>() == size_of::<*mut Self>()); + + unsafe { + // A pointer cast doesn't work here because rustc can't tell that + // the vtables match (because of the `?Sized` restriction relaxation). + // A `transmute` doesn't work because the layout of Arc is unspecified. + // + // SAFETY: + // * The unsafe contract requires that pointers to Inner and Self have + // identical representations, and that the size and alignment of Inner + // and Self are the same, which meets the safety requirements of + // Arc::from_raw + let inner_ptr: *const Inner = Arc::into_raw(s); + let wrapper_ptr: *const Self = transmute!(inner_ptr); + Arc::from_raw(wrapper_ptr) + } + } + /// Convert a vec of the wrapper type into a vec of the inner type. fn peel_vec(s: Vec) -> Vec where @@ -408,5 +653,49 @@ pub trait TransparentWrapperAlloc: Box::from_raw(inner_ptr) } } + + /// Convert an [`Rc`](alloc::rc::Rc) to the wrapper type into an `Rc` to the + /// inner type. + #[inline] + fn peel_rc(s: Rc) -> Rc { + assert!(size_of::<*mut Inner>() == size_of::<*mut Self>()); + + unsafe { + // A pointer cast doesn't work here because rustc can't tell that + // the vtables match (because of the `?Sized` restriction relaxation). + // A `transmute` doesn't work because the layout of Rc is unspecified. + // + // SAFETY: + // * The unsafe contract requires that pointers to Inner and Self have + // identical representations, and that the size and alignment of Inner + // and Self are the same, which meets the safety requirements of + // Rc::from_raw + let wrapper_ptr: *const Self = Rc::into_raw(s); + let inner_ptr: *const Inner = transmute!(wrapper_ptr); + Rc::from_raw(inner_ptr) + } + } + + /// Convert an [`Arc`](alloc::sync::Arc) to the wrapper type into an `Arc` to + /// the inner type. + #[inline] + fn peel_arc(s: Arc) -> Arc { + assert!(size_of::<*mut Inner>() == size_of::<*mut Self>()); + + unsafe { + // A pointer cast doesn't work here because rustc can't tell that + // the vtables match (because of the `?Sized` restriction relaxation). + // A `transmute` doesn't work because the layout of Arc is unspecified. + // + // SAFETY: + // * The unsafe contract requires that pointers to Inner and Self have + // identical representations, and that the size and alignment of Inner + // and Self are the same, which meets the safety requirements of + // Arc::from_raw + let wrapper_ptr: *const Self = Arc::into_raw(s); + let inner_ptr: *const Inner = transmute!(wrapper_ptr); + Arc::from_raw(inner_ptr) + } + } } impl> TransparentWrapperAlloc for T {} diff --git a/tests/transparent.rs b/tests/transparent.rs index 9bcbac8..78d4fe1 100644 --- a/tests/transparent.rs +++ b/tests/transparent.rs @@ -77,6 +77,7 @@ fn test_transparent_wrapper() { #[cfg(feature = "extern_crate_alloc")] { use bytemuck::allocation::TransparentWrapperAlloc; + use std::{rc::Rc, sync::Arc}; let a: Vec = vec![Foreign::default(); 2]; @@ -92,5 +93,19 @@ fn test_transparent_wrapper() { assert_eq!(&*e, &0); let f: Box = Wrapper::peel_box(e); assert_eq!(&*f, &0); + + let g: Rc = Rc::new(Foreign::default()); + + let h: Rc = Wrapper::wrap_rc(g); + assert_eq!(&*h, &0); + let i: Rc = Wrapper::peel_rc(h); + assert_eq!(&*i, &0); + + let j: Arc = Arc::new(Foreign::default()); + + let k: Arc = Wrapper::wrap_arc(j); + assert_eq!(&*k, &0); + let l: Arc = Wrapper::peel_arc(k); + assert_eq!(&*l, &0); } }