Add an IndexUnchecked trait that uses asm! (#805)

* Add an IndexUnchecked trait to spirv-std/arch

* Slap some #[spirv_std_macros::gpu_only] on there

* Spelling

* Add safety sections to the docs

* Improve documentation, implement for non-spirv targets
This commit is contained in:
Ashley 2021-11-28 22:56:13 -08:00 committed by GitHub
parent 6232d95256
commit 9fd3c9e172
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 0 deletions

View File

@ -245,3 +245,92 @@ pub fn signed_min<T: SignedInteger>(a: T, b: T) -> T {
pub fn signed_max<T: SignedInteger>(a: T, b: T) -> T {
unsafe { call_glsl_op_with_ints::<_, 42>(a, b) }
}
/// Index into an array without bounds checking.
///
/// The main purpose of this trait is to work around the fact that the regular `get_unchecked*`
/// methods do not work in in SPIR-V.
pub trait IndexUnchecked<T> {
/// Returns a reference to the element at `index`. The equivalent of `get_unchecked`.
///
/// # Safety
/// Behavior is undefined if the `index` value is greater than or equal to the length of the array.
unsafe fn index_unchecked(&self, index: usize) -> &T;
/// Returns a mutable reference to the element at `index`. The equivalent of `get_unchecked_mut`.
///
/// # Safety
/// Behavior is undefined if the `index` value is greater than or equal to the length of the array.
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T;
}
impl<T> IndexUnchecked<T> for [T] {
#[cfg(target_arch = "spirv")]
unsafe fn index_unchecked(&self, index: usize) -> &T {
asm!(
"%slice_ptr = OpLoad _ {slice_ptr_ptr}",
"%data_ptr = OpCompositeExtract _ %slice_ptr 0",
"%val_ptr = OpAccessChain _ %data_ptr {index}",
"OpReturnValue %val_ptr",
slice_ptr_ptr = in(reg) &self,
index = in(reg) index,
options(noreturn)
)
}
#[cfg(not(target_arch = "spirv"))]
unsafe fn index_unchecked(&self, index: usize) -> &T {
self.get_unchecked(index)
}
#[cfg(target_arch = "spirv")]
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
asm!(
"%slice_ptr = OpLoad _ {slice_ptr_ptr}",
"%data_ptr = OpCompositeExtract _ %slice_ptr 0",
"%val_ptr = OpAccessChain _ %data_ptr {index}",
"OpReturnValue %val_ptr",
slice_ptr_ptr = in(reg) &self,
index = in(reg) index,
options(noreturn)
)
}
#[cfg(not(target_arch = "spirv"))]
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
self.get_unchecked_mut(index)
}
}
impl<T, const N: usize> IndexUnchecked<T> for [T; N] {
#[cfg(target_arch = "spirv")]
unsafe fn index_unchecked(&self, index: usize) -> &T {
asm!(
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
"OpReturnValue %val_ptr",
array_ptr = in(reg) self,
index = in(reg) index,
options(noreturn)
)
}
#[cfg(not(target_arch = "spirv"))]
unsafe fn index_unchecked(&self, index: usize) -> &T {
self.get_unchecked(index)
}
#[cfg(target_arch = "spirv")]
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
asm!(
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
"OpReturnValue %val_ptr",
array_ptr = in(reg) self,
index = in(reg) index,
options(noreturn)
)
}
#[cfg(not(target_arch = "spirv"))]
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
self.get_unchecked_mut(index)
}
}

View File

@ -0,0 +1,14 @@
// build-pass
use spirv_std::arch::IndexUnchecked;
#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] runtime_array: &mut [u32],
#[spirv(descriptor_set = 1, binding = 1, storage_buffer)] array: &mut [u32; 5],
) {
unsafe {
*runtime_array.index_unchecked_mut(0) = *array.index_unchecked(0);
*array.index_unchecked_mut(1) = *runtime_array.index_unchecked(1);
}
}