diff --git a/crates/spirv-std/src/arch.rs b/crates/spirv-std/src/arch.rs index 2319751dac..4fc6a28f67 100644 --- a/crates/spirv-std/src/arch.rs +++ b/crates/spirv-std/src/arch.rs @@ -245,3 +245,92 @@ pub fn signed_min(a: T, b: T) -> T { pub fn signed_max(a: T, b: T) -> T { unsafe { call_glsl_op_with_ints::<_, 42>(a, b) } } + +/// Index into an array without bounds checking. +/// +/// The main purpose of this trait is to work around the fact that the regular `get_unchecked*` +/// methods do not work in in SPIR-V. +pub trait IndexUnchecked { + /// Returns a reference to the element at `index`. The equivalent of `get_unchecked`. + /// + /// # Safety + /// Behavior is undefined if the `index` value is greater than or equal to the length of the array. + unsafe fn index_unchecked(&self, index: usize) -> &T; + /// Returns a mutable reference to the element at `index`. The equivalent of `get_unchecked_mut`. + /// + /// # Safety + /// Behavior is undefined if the `index` value is greater than or equal to the length of the array. + unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T; +} + +impl IndexUnchecked for [T] { + #[cfg(target_arch = "spirv")] + unsafe fn index_unchecked(&self, index: usize) -> &T { + asm!( + "%slice_ptr = OpLoad _ {slice_ptr_ptr}", + "%data_ptr = OpCompositeExtract _ %slice_ptr 0", + "%val_ptr = OpAccessChain _ %data_ptr {index}", + "OpReturnValue %val_ptr", + slice_ptr_ptr = in(reg) &self, + index = in(reg) index, + options(noreturn) + ) + } + + #[cfg(not(target_arch = "spirv"))] + unsafe fn index_unchecked(&self, index: usize) -> &T { + self.get_unchecked(index) + } + + #[cfg(target_arch = "spirv")] + unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T { + asm!( + "%slice_ptr = OpLoad _ {slice_ptr_ptr}", + "%data_ptr = OpCompositeExtract _ %slice_ptr 0", + "%val_ptr = OpAccessChain _ %data_ptr {index}", + "OpReturnValue %val_ptr", + slice_ptr_ptr = in(reg) &self, + index = in(reg) index, + options(noreturn) + ) + } + + #[cfg(not(target_arch = "spirv"))] + unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T { + self.get_unchecked_mut(index) + } +} + +impl IndexUnchecked for [T; N] { + #[cfg(target_arch = "spirv")] + unsafe fn index_unchecked(&self, index: usize) -> &T { + asm!( + "%val_ptr = OpAccessChain _ {array_ptr} {index}", + "OpReturnValue %val_ptr", + array_ptr = in(reg) self, + index = in(reg) index, + options(noreturn) + ) + } + + #[cfg(not(target_arch = "spirv"))] + unsafe fn index_unchecked(&self, index: usize) -> &T { + self.get_unchecked(index) + } + + #[cfg(target_arch = "spirv")] + unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T { + asm!( + "%val_ptr = OpAccessChain _ {array_ptr} {index}", + "OpReturnValue %val_ptr", + array_ptr = in(reg) self, + index = in(reg) index, + options(noreturn) + ) + } + + #[cfg(not(target_arch = "spirv"))] + unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T { + self.get_unchecked_mut(index) + } +} diff --git a/tests/ui/arch/index_unchecked.rs b/tests/ui/arch/index_unchecked.rs new file mode 100644 index 0000000000..31e94e948a --- /dev/null +++ b/tests/ui/arch/index_unchecked.rs @@ -0,0 +1,14 @@ +// build-pass + +use spirv_std::arch::IndexUnchecked; + +#[spirv(fragment)] +pub fn main( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] runtime_array: &mut [u32], + #[spirv(descriptor_set = 1, binding = 1, storage_buffer)] array: &mut [u32; 5], +) { + unsafe { + *runtime_array.index_unchecked_mut(0) = *array.index_unchecked(0); + *array.index_unchecked_mut(1) = *runtime_array.index_unchecked(1); + } +}