mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 00:04:11 +00:00
Fix issues with some raytracing functions and add tests
This commit is contained in:
parent
2d5b8e6b0e
commit
0bd0bf0f11
@ -1,5 +1,5 @@
|
||||
use super::Builder;
|
||||
use crate::builder_spirv::{BuilderCursor, SpirvValue};
|
||||
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
|
||||
use crate::codegen_cx::CodegenCx;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use rspirv::dr;
|
||||
@ -304,8 +304,26 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
||||
}
|
||||
.def(self.span(), self),
|
||||
Op::TypeArray => {
|
||||
self.err("OpTypeArray in asm! is not supported yet");
|
||||
return;
|
||||
let count = inst.operands[1].unwrap_id_ref();
|
||||
let get_count_ty = || -> Option<Word> {
|
||||
let emit = self.emit();
|
||||
let func = &emit.module_ref().functions[emit.selected_function()?];
|
||||
let insts = &func.blocks[emit.selected_block()?].instructions;
|
||||
let inst = insts.iter().find(|i| i.result_id == Some(count))?;
|
||||
inst.result_type
|
||||
};
|
||||
let count_ty = match get_count_ty() {
|
||||
Some(ty) => ty,
|
||||
None => {
|
||||
self.err("Unable to find constant for OpTypeArray count");
|
||||
return;
|
||||
}
|
||||
};
|
||||
SpirvType::Array {
|
||||
element: inst.operands[0].unwrap_id_ref(),
|
||||
count: count.with_type(count_ty),
|
||||
}
|
||||
.def(self.span(), self)
|
||||
}
|
||||
Op::TypeRuntimeArray => SpirvType::RuntimeArray {
|
||||
element: inst.operands[0].unwrap_id_ref(),
|
||||
|
@ -579,18 +579,16 @@ impl RayQuery {
|
||||
#[spirv_std_macros::gpu_only]
|
||||
#[doc(alias = "OpRayQueryGetWorldRayDirectionKHR")]
|
||||
#[inline]
|
||||
pub unsafe fn get_world_ray_direction<V: Vector<f32, 3>, const INTERSECTION: u32>(&self) -> V {
|
||||
pub unsafe fn get_world_ray_direction<V: Vector<f32, 3>>(&self) -> V {
|
||||
let mut result = Default::default();
|
||||
|
||||
asm! {
|
||||
"%u32 = OpTypeInt 32 0",
|
||||
"%f32 = OpTypeFloat 32",
|
||||
"%f32x3 = OpTypeVector %f32 3",
|
||||
"%intersection = OpConstant %u32 {intersection}",
|
||||
"%result = OpRayQueryGetWorldRayDirectionKHR %f32x3 {ray_query} %intersection",
|
||||
"%result = OpRayQueryGetWorldRayDirectionKHR %f32x3 {ray_query}",
|
||||
"OpStore {result} %result",
|
||||
ray_query = in(reg) self,
|
||||
intersection = const INTERSECTION,
|
||||
result = in(reg) &mut result,
|
||||
}
|
||||
|
||||
@ -601,18 +599,16 @@ impl RayQuery {
|
||||
#[spirv_std_macros::gpu_only]
|
||||
#[doc(alias = "OpRayQueryGetWorldRayOriginKHR")]
|
||||
#[inline]
|
||||
pub unsafe fn get_world_ray_origin<V: Vector<f32, 3>, const INTERSECTION: u32>(&self) -> V {
|
||||
pub unsafe fn get_world_ray_origin<V: Vector<f32, 3>>(&self) -> V {
|
||||
let mut result = Default::default();
|
||||
|
||||
asm! {
|
||||
"%u32 = OpTypeInt 32 0",
|
||||
"%f32 = OpTypeFloat 32",
|
||||
"%f32x3 = OpTypeVector %f32 3",
|
||||
"%intersection = OpConstant %u32 {intersection}",
|
||||
"%result = OpRayQueryGetWorldRayOriginKHR %f32x3 {ray_query} %intersection",
|
||||
"%result = OpRayQueryGetWorldRayOriginKHR %f32x3 {ray_query}",
|
||||
"OpStore {result} %result",
|
||||
ray_query = in(reg) self,
|
||||
intersection = const INTERSECTION,
|
||||
result = in(reg) &mut result,
|
||||
}
|
||||
|
||||
@ -626,15 +622,22 @@ impl RayQuery {
|
||||
#[inline]
|
||||
pub unsafe fn get_intersection_object_to_world<V: Vector<f32, 3>, const INTERSECTION: u32>(
|
||||
&self,
|
||||
) -> V {
|
||||
) -> [V; 4] {
|
||||
let mut result = Default::default();
|
||||
|
||||
asm! {
|
||||
"%u32 = OpTypeInt 32 0",
|
||||
"%f32 = OpTypeFloat 32",
|
||||
"%four = OpConstant %u32 4",
|
||||
"%f32x3 = OpTypeVector %f32 3",
|
||||
"%f32x3x4 = OpTypeMatrix %f32x3 4",
|
||||
"%intersection = OpConstant %u32 {intersection}",
|
||||
"%result = OpRayQueryGetWorldRayOriginKHR %f32x3 {ray_query} %intersection",
|
||||
"%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection",
|
||||
"%col0 = OpCompositeExtract %f32x3 %matrix 0",
|
||||
"%col1 = OpCompositeExtract %f32x3 %matrix 1",
|
||||
"%col2 = OpCompositeExtract %f32x3 %matrix 2",
|
||||
"%col3 = OpCompositeExtract %f32x3 %matrix 3",
|
||||
"%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3",
|
||||
"OpStore {result} %result",
|
||||
ray_query = in(reg) self,
|
||||
intersection = const INTERSECTION,
|
||||
|
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
handle.get_intersection_candidate_aabb_opaque();
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
let direction: glam::Vec3 = handle.get_intersection_object_ray_direction::<_, 5>();
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
let origin: glam::Vec3 = handle.get_intersection_object_ray_origin::<_, 5>();
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
let matrix: [glam::Vec3; 4] = handle.get_intersection_object_to_world::<_, 5>();
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
handle.get_intersection_primitive_index::<5>();
|
||||
}
|
||||
}
|
14
tests/ui/arch/ray_query_get_ray_flags_khr.rs
Normal file
14
tests/ui/arch/ray_query_get_ray_flags_khr.rs
Normal file
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
let flags = handle.get_ray_flags();
|
||||
}
|
||||
}
|
14
tests/ui/arch/ray_query_get_world_ray_direction_khr.rs
Normal file
14
tests/ui/arch/ray_query_get_world_ray_direction_khr.rs
Normal file
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
let direction: glam::Vec3 = handle.get_world_ray_direction();
|
||||
}
|
||||
}
|
14
tests/ui/arch/ray_query_get_world_ray_origin_khr.rs
Normal file
14
tests/ui/arch/ray_query_get_world_ray_origin_khr.rs
Normal file
@ -0,0 +1,14 @@
|
||||
// build-pass
|
||||
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
|
||||
|
||||
use glam::Vec3;
|
||||
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
|
||||
unsafe {
|
||||
spirv_std::ray_query!(let mut handle);
|
||||
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
|
||||
let origin: glam::Vec3 = handle.get_world_ray_origin();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user