spv-out: fix acceleration structure in a function argument

This commit is contained in:
Dzmitry Malyshau 2024-07-14 22:16:50 -07:00 committed by Teodor Tanasoaia
parent d02e2949b2
commit 1b4e8ada63
5 changed files with 193 additions and 162 deletions

View File

@ -160,6 +160,7 @@ By @teoxoy in [#5901](https://github.com/gfx-rs/wgpu/pull/5901)
- Implement `WGSL`'s `unpack4xI8`,`unpack4xU8`,`pack4xI8` and `pack4xU8`. By @VlaDexa in [#5424](https://github.com/gfx-rs/wgpu/pull/5424)
- Began work adding support for atomics to the SPIR-V frontend. Tracking issue is [here](https://github.com/gfx-rs/wgpu/issues/4489). By @schell in [#5702](https://github.com/gfx-rs/wgpu/pull/5702).
- In hlsl-out, allow passing information about the fragment entry point to omit vertex outputs that are not in the fragment inputs. By @Imberflur in [#5531](https://github.com/gfx-rs/wgpu/pull/5531)
- In spv-out, allow passing `acceleration_structure` as a function argument. By @kvark in [#5961](https://github.com/gfx-rs/wgpu/pull/5961)
```diff
let writer: naga::back::hlsl::Writer = /* ... */;

View File

@ -254,7 +254,9 @@ impl crate::TypeInner {
/// Returns true if this is a handle to a type rather than the type directly.
pub const fn is_handle(&self) -> bool {
match *self {
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true,
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure { .. } => true,
_ => false,
}
}

View File

@ -1,6 +1,3 @@
@group(0) @binding(0)
var acc_struct: acceleration_structure;
/*
let RAY_FLAG_NONE = 0x00u;
let RAY_FLAG_OPAQUE = 0x01u;
@ -43,6 +40,18 @@ struct RayIntersection {
}
*/
fn query_loop(pos: vec3<f32>, dir: vec3<f32>, acs: acceleration_structure) -> RayIntersection {
var rq: ray_query;
rayQueryInitialize(&rq, acs, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, pos, dir));
while (rayQueryProceed(&rq)) {}
return rayQueryGetCommittedIntersection(&rq);
}
@group(0) @binding(0)
var acc_struct: acceleration_structure;
struct Output {
visible: u32,
normal: vec3<f32>,
@ -58,16 +67,14 @@ fn get_torus_normal(world_point: vec3<f32>, intersection: RayIntersection) -> ve
return normalize(world_point - world_point_on_guiding_line);
}
@compute @workgroup_size(1)
fn main() {
var rq: ray_query;
let pos = vec3<f32>(0.0);
let dir = vec3<f32>(0.0, 1.0, 0.0);
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), dir));
let intersection = query_loop(pos, dir, acc_struct);
while (rayQueryProceed(&rq)) {}
let intersection = rayQueryGetCommittedIntersection(&rq);
output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE);
output.normal = get_torus_normal(dir * intersection.t, intersection);
}

View File

@ -13,11 +13,6 @@ constexpr metal::uint _map_intersection_type(const metal::raytracing::intersecti
ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0;
}
struct Output {
uint visible;
char _pad1[12];
metal::float3 normal;
};
struct RayIntersection {
uint kind;
float t;
@ -40,6 +35,34 @@ struct RayDesc {
metal::float3 origin;
metal::float3 dir;
};
struct Output {
uint visible;
char _pad1[12];
metal::float3 normal;
};
RayIntersection query_loop(
metal::float3 pos,
metal::float3 dir,
metal::raytracing::instance_acceleration_structure acs
) {
_RayQuery rq = {};
RayDesc _e8 = RayDesc {4u, 255u, 0.1, 100.0, pos, dir};
rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
rq.intersector.set_opacity_cull_mode((_e8.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e8.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
rq.intersector.force_opacity((_e8.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e8.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq.intersector.accept_any_intersection((_e8.flags & 4) != 0);
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e8.origin, _e8.dir, _e8.tmin, _e8.tmax), acs, _e8.cull_mask); rq.ready = true;
while(true) {
bool _e9 = rq.ready;
rq.ready = false;
if (_e9) {
} else {
break;
}
}
return RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
}
metal::float3 get_torus_normal(
metal::float3 world_point,
@ -55,25 +78,11 @@ kernel void main_(
metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
, device Output& output [[user(fake0)]]
) {
_RayQuery rq = {};
metal::float3 dir = metal::float3(0.0, 1.0, 0.0);
RayDesc _e12 = RayDesc {4u, 255u, 0.1, 100.0, metal::float3(0.0), dir};
rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq.intersector.accept_any_intersection((_e12.flags & 4) != 0);
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true;
while(true) {
bool _e13 = rq.ready;
rq.ready = false;
if (_e13) {
} else {
break;
}
}
RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
output.visible = static_cast<uint>(intersection_1.kind == 0u);
metal::float3 _e25 = get_torus_normal(dir * intersection_1.t, intersection_1);
output.normal = _e25;
metal::float3 pos_1 = metal::float3(0.0);
metal::float3 dir_1 = metal::float3(0.0, 1.0, 0.0);
RayIntersection _e7 = query_loop(pos_1, dir_1, acc_struct);
output.visible = static_cast<uint>(_e7.kind == 0u);
metal::float3 _e18 = get_torus_normal(dir_1 * _e7.t, _e7);
output.normal = _e18;
return;
}

View File

@ -1,37 +1,37 @@
; SPIR-V
; Version: 1.4
; Generator: rspirv
; Bound: 95
; Bound: 104
OpCapability Shader
OpCapability RayQueryKHR
OpExtension "SPV_KHR_ray_query"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %41 "main" %15 %17
OpExecutionMode %41 LocalSize 1 1 1
OpMemberDecorate %7 0 Offset 0
OpMemberDecorate %7 1 Offset 16
OpMemberDecorate %11 0 Offset 0
OpMemberDecorate %11 1 Offset 4
OpMemberDecorate %11 2 Offset 8
OpMemberDecorate %11 3 Offset 12
OpMemberDecorate %11 4 Offset 16
OpMemberDecorate %11 5 Offset 20
OpMemberDecorate %11 6 Offset 24
OpMemberDecorate %11 7 Offset 28
OpMemberDecorate %11 8 Offset 36
OpMemberDecorate %11 9 Offset 48
OpMemberDecorate %11 9 ColMajor
OpMemberDecorate %11 9 MatrixStride 16
OpMemberDecorate %11 10 Offset 112
OpMemberDecorate %11 10 ColMajor
OpMemberDecorate %11 10 MatrixStride 16
OpMemberDecorate %14 0 Offset 0
OpMemberDecorate %14 1 Offset 4
OpMemberDecorate %14 2 Offset 8
OpMemberDecorate %14 3 Offset 12
OpMemberDecorate %14 4 Offset 16
OpMemberDecorate %14 5 Offset 32
OpEntryPoint GLCompute %84 "main" %15 %17
OpExecutionMode %84 LocalSize 1 1 1
OpMemberDecorate %10 0 Offset 0
OpMemberDecorate %10 1 Offset 4
OpMemberDecorate %10 2 Offset 8
OpMemberDecorate %10 3 Offset 12
OpMemberDecorate %10 4 Offset 16
OpMemberDecorate %10 5 Offset 20
OpMemberDecorate %10 6 Offset 24
OpMemberDecorate %10 7 Offset 28
OpMemberDecorate %10 8 Offset 36
OpMemberDecorate %10 9 Offset 48
OpMemberDecorate %10 9 ColMajor
OpMemberDecorate %10 9 MatrixStride 16
OpMemberDecorate %10 10 Offset 112
OpMemberDecorate %10 10 ColMajor
OpMemberDecorate %10 10 MatrixStride 16
OpMemberDecorate %12 0 Offset 0
OpMemberDecorate %12 1 Offset 4
OpMemberDecorate %12 2 Offset 8
OpMemberDecorate %12 3 Offset 12
OpMemberDecorate %12 4 Offset 16
OpMemberDecorate %12 5 Offset 32
OpMemberDecorate %13 0 Offset 0
OpMemberDecorate %13 1 Offset 16
OpDecorate %15 DescriptorSet 0
OpDecorate %15 Binding 0
OpDecorate %17 DescriptorSet 0
@ -39,114 +39,126 @@ OpDecorate %17 Binding 1
OpDecorate %18 Block
OpMemberDecorate %18 0 Offset 0
%2 = OpTypeVoid
%3 = OpTypeAccelerationStructureNV
%4 = OpTypeInt 32 0
%6 = OpTypeFloat 32
%5 = OpTypeVector %6 3
%7 = OpTypeStruct %4 %5
%8 = OpTypeVector %6 2
%9 = OpTypeBool
%10 = OpTypeMatrix %5 4
%11 = OpTypeStruct %4 %6 %4 %4 %4 %4 %4 %8 %9 %10 %10
%12 = OpTypeVector %6 4
%13 = OpTypeRayQueryKHR
%14 = OpTypeStruct %4 %4 %6 %6 %5 %5
%16 = OpTypePointer UniformConstant %3
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%5 = OpTypeAccelerationStructureNV
%6 = OpTypeInt 32 0
%7 = OpTypeVector %4 2
%8 = OpTypeBool
%9 = OpTypeMatrix %3 4
%10 = OpTypeStruct %6 %4 %6 %6 %6 %6 %6 %7 %8 %9 %9
%11 = OpTypeRayQueryKHR
%12 = OpTypeStruct %6 %6 %4 %4 %3 %3
%13 = OpTypeStruct %6 %3
%14 = OpTypeVector %4 4
%16 = OpTypePointer UniformConstant %5
%15 = OpVariable %16 UniformConstant
%18 = OpTypeStruct %7
%18 = OpTypeStruct %13
%19 = OpTypePointer StorageBuffer %18
%17 = OpVariable %19 StorageBuffer
%24 = OpTypeFunction %5 %5 %11
%25 = OpConstant %6 1.0
%26 = OpConstant %6 2.4
%27 = OpConstant %6 0.0
%42 = OpTypeFunction %2
%44 = OpTypePointer StorageBuffer %7
%45 = OpConstant %4 0
%47 = OpConstantComposite %5 %27 %25 %27
%48 = OpConstant %4 4
%49 = OpConstant %4 255
%50 = OpConstantComposite %5 %27 %27 %27
%51 = OpConstant %6 0.1
%52 = OpConstant %6 100.0
%53 = OpConstantComposite %14 %48 %49 %51 %52 %50 %47
%55 = OpTypePointer Function %13
%72 = OpConstant %4 1
%85 = OpTypePointer StorageBuffer %4
%90 = OpTypePointer StorageBuffer %5
%23 = OpFunction %5 None %24
%21 = OpFunctionParameter %5
%22 = OpFunctionParameter %11
%26 = OpTypeFunction %10 %3 %3 %16
%27 = OpConstant %6 4
%28 = OpConstant %6 255
%29 = OpConstant %4 0.1
%30 = OpConstant %4 100.0
%32 = OpTypePointer Function %11
%50 = OpConstant %6 1
%67 = OpTypeFunction %3 %3 %10
%68 = OpConstant %4 1.0
%69 = OpConstant %4 2.4
%70 = OpConstant %4 0.0
%85 = OpTypeFunction %2
%87 = OpTypePointer StorageBuffer %13
%88 = OpConstant %6 0
%90 = OpConstantComposite %3 %70 %70 %70
%91 = OpConstantComposite %3 %70 %68 %70
%94 = OpTypePointer StorageBuffer %6
%99 = OpTypePointer StorageBuffer %3
%25 = OpFunction %10 None %26
%21 = OpFunctionParameter %3
%22 = OpFunctionParameter %3
%23 = OpFunctionParameter %16
%20 = OpLabel
OpBranch %28
%28 = OpLabel
%29 = OpCompositeExtract %10 %22 10
%30 = OpCompositeConstruct %12 %21 %25
%31 = OpMatrixTimesVector %5 %29 %30
%32 = OpVectorShuffle %8 %31 %31 0 1
%33 = OpExtInst %8 %1 Normalize %32
%34 = OpVectorTimesScalar %8 %33 %26
%35 = OpCompositeExtract %10 %22 9
%36 = OpCompositeConstruct %12 %34 %27 %25
%37 = OpMatrixTimesVector %5 %35 %36
%38 = OpFSub %5 %21 %37
%39 = OpExtInst %5 %1 Normalize %38
OpReturnValue %39
%31 = OpVariable %32 Function
%24 = OpLoad %5 %23
OpBranch %33
%33 = OpLabel
%34 = OpCompositeConstruct %12 %27 %28 %29 %30 %21 %22
%35 = OpCompositeExtract %6 %34 0
%36 = OpCompositeExtract %6 %34 1
%37 = OpCompositeExtract %4 %34 2
%38 = OpCompositeExtract %4 %34 3
%39 = OpCompositeExtract %3 %34 4
%40 = OpCompositeExtract %3 %34 5
OpRayQueryInitializeKHR %31 %24 %35 %36 %39 %37 %40 %38
OpBranch %41
%41 = OpLabel
OpLoopMerge %42 %44 None
OpBranch %43
%43 = OpLabel
%45 = OpRayQueryProceedKHR %8 %31
OpSelectionMerge %46 None
OpBranchConditional %45 %46 %47
%47 = OpLabel
OpBranch %42
%46 = OpLabel
OpBranch %48
%48 = OpLabel
OpBranch %49
%49 = OpLabel
OpBranch %44
%44 = OpLabel
OpBranch %41
%42 = OpLabel
%51 = OpRayQueryGetIntersectionTypeKHR %6 %31 %50
%52 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %31 %50
%53 = OpRayQueryGetIntersectionInstanceIdKHR %6 %31 %50
%54 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %31 %50
%55 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %31 %50
%56 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %31 %50
%57 = OpRayQueryGetIntersectionTKHR %4 %31 %50
%58 = OpRayQueryGetIntersectionBarycentricsKHR %7 %31 %50
%59 = OpRayQueryGetIntersectionFrontFaceKHR %8 %31 %50
%60 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %31 %50
%61 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %31 %50
%62 = OpCompositeConstruct %10 %51 %57 %52 %53 %54 %55 %56 %58 %59 %60 %61
OpReturnValue %62
OpFunctionEnd
%41 = OpFunction %2 None %42
%40 = OpLabel
%54 = OpVariable %55 Function
%43 = OpLoad %3 %15
%46 = OpAccessChain %44 %17 %45
OpBranch %56
%56 = OpLabel
%57 = OpCompositeExtract %4 %53 0
%58 = OpCompositeExtract %4 %53 1
%59 = OpCompositeExtract %6 %53 2
%60 = OpCompositeExtract %6 %53 3
%61 = OpCompositeExtract %5 %53 4
%62 = OpCompositeExtract %5 %53 5
OpRayQueryInitializeKHR %54 %43 %57 %58 %61 %59 %62 %60
OpBranch %63
%66 = OpFunction %3 None %67
%64 = OpFunctionParameter %3
%65 = OpFunctionParameter %10
%63 = OpLabel
OpLoopMerge %64 %66 None
OpBranch %65
%65 = OpLabel
%67 = OpRayQueryProceedKHR %9 %54
OpSelectionMerge %68 None
OpBranchConditional %67 %68 %69
%69 = OpLabel
OpBranch %64
%68 = OpLabel
OpBranch %70
%70 = OpLabel
OpBranch %71
%71 = OpLabel
OpBranch %66
%66 = OpLabel
OpBranch %63
%64 = OpLabel
%73 = OpRayQueryGetIntersectionTypeKHR %4 %54 %72
%74 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %4 %54 %72
%75 = OpRayQueryGetIntersectionInstanceIdKHR %4 %54 %72
%76 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %4 %54 %72
%77 = OpRayQueryGetIntersectionGeometryIndexKHR %4 %54 %72
%78 = OpRayQueryGetIntersectionPrimitiveIndexKHR %4 %54 %72
%79 = OpRayQueryGetIntersectionTKHR %6 %54 %72
%80 = OpRayQueryGetIntersectionBarycentricsKHR %8 %54 %72
%81 = OpRayQueryGetIntersectionFrontFaceKHR %9 %54 %72
%82 = OpRayQueryGetIntersectionObjectToWorldKHR %10 %54 %72
%83 = OpRayQueryGetIntersectionWorldToObjectKHR %10 %54 %72
%84 = OpCompositeConstruct %11 %73 %79 %74 %75 %76 %77 %78 %80 %81 %82 %83
%86 = OpCompositeExtract %4 %84 0
%87 = OpIEqual %9 %86 %45
%88 = OpSelect %4 %87 %72 %45
%89 = OpAccessChain %85 %46 %45
OpStore %89 %88
%91 = OpCompositeExtract %6 %84 1
%92 = OpVectorTimesScalar %5 %47 %91
%93 = OpFunctionCall %5 %23 %92 %84
%94 = OpAccessChain %90 %46 %72
OpStore %94 %93
%72 = OpCompositeExtract %9 %65 10
%73 = OpCompositeConstruct %14 %64 %68
%74 = OpMatrixTimesVector %3 %72 %73
%75 = OpVectorShuffle %7 %74 %74 0 1
%76 = OpExtInst %7 %1 Normalize %75
%77 = OpVectorTimesScalar %7 %76 %69
%78 = OpCompositeExtract %9 %65 9
%79 = OpCompositeConstruct %14 %77 %70 %68
%80 = OpMatrixTimesVector %3 %78 %79
%81 = OpFSub %3 %64 %80
%82 = OpExtInst %3 %1 Normalize %81
OpReturnValue %82
OpFunctionEnd
%84 = OpFunction %2 None %85
%83 = OpLabel
%86 = OpLoad %5 %15
%89 = OpAccessChain %87 %17 %88
OpBranch %92
%92 = OpLabel
%93 = OpFunctionCall %10 %25 %90 %91 %15
%95 = OpCompositeExtract %6 %93 0
%96 = OpIEqual %8 %95 %88
%97 = OpSelect %6 %96 %50 %88
%98 = OpAccessChain %94 %89 %88
OpStore %98 %97
%100 = OpCompositeExtract %4 %93 1
%101 = OpVectorTimesScalar %3 %91 %100
%102 = OpFunctionCall %3 %66 %101 %93
%103 = OpAccessChain %99 %89 %50
OpStore %103 %102
OpReturn
OpFunctionEnd