[naga] Adjust RayQuery statements in override processing.

This commit is contained in:
Jim Blandy 2024-03-25 19:11:23 -07:00 committed by Teodor Tanasoaia
parent 906ed128de
commit ba19d8df34
8 changed files with 702 additions and 2 deletions

View File

@ -633,7 +633,7 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
Statement::Call {
ref mut arguments,
ref mut result,
..
function: _,
} => {
for argument in arguments.iter_mut() {
adjust(argument);
@ -642,8 +642,24 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
adjust(e);
}
}
Statement::RayQuery { ref mut query, .. } => {
Statement::RayQuery {
ref mut query,
ref mut fun,
} => {
adjust(query);
match *fun {
crate::RayQueryFunction::Initialize {
ref mut acceleration_structure,
ref mut descriptor,
} => {
adjust(acceleration_structure);
adjust(descriptor);
}
crate::RayQueryFunction::Proceed { ref mut result } => {
adjust(result);
}
crate::RayQueryFunction::Terminate => {}
}
}
Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
}

View File

@ -0,0 +1,18 @@
(
god_mode: true,
spv: (
version: (1, 4),
separate_entry_points: true,
),
msl: (
lang_version: (2, 4),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
zero_initialize_workgroup_memory: false,
per_entry_point_map: {},
inline_samplers: [],
),
pipeline_constants: {
"o": 2.0
}
)

View File

@ -0,0 +1,21 @@
override o: f32;
@group(0) @binding(0)
var acc_struct: acceleration_structure;
@compute @workgroup_size(1)
fn main() {
var rq: ray_query;
let desc = RayDesc(
RAY_FLAG_TERMINATE_ON_FIRST_HIT,
0xFFu,
o * 17.0,
o * 19.0,
vec3<f32>(o * 23.0),
vec3<f32>(o * 29.0, o * 31.0, o * 37.0),
);
rayQueryInitialize(&rq, acc_struct, desc);
while (rayQueryProceed(&rq)) {}
}

View File

@ -0,0 +1,259 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Float,
width: 4,
)),
),
(
name: None,
inner: AccelerationStructure,
),
(
name: None,
inner: RayQuery,
),
(
name: None,
inner: Scalar((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Vector(
size: Tri,
scalar: (
kind: Float,
width: 4,
),
),
),
(
name: Some("RayDesc"),
inner: Struct(
members: [
(
name: Some("flags"),
ty: 4,
binding: None,
offset: 0,
),
(
name: Some("cull_mask"),
ty: 4,
binding: None,
offset: 4,
),
(
name: Some("tmin"),
ty: 1,
binding: None,
offset: 8,
),
(
name: Some("tmax"),
ty: 1,
binding: None,
offset: 12,
),
(
name: Some("origin"),
ty: 5,
binding: None,
offset: 16,
),
(
name: Some("dir"),
ty: 5,
binding: None,
offset: 32,
),
],
span: 48,
),
),
],
special_types: (
ray_desc: Some(6),
ray_intersection: None,
predeclared_types: {},
),
constants: [],
overrides: [
(
name: Some("o"),
id: None,
ty: 1,
init: None,
),
],
global_variables: [
(
name: Some("acc_struct"),
space: Handle,
binding: Some((
group: 0,
binding: 0,
)),
ty: 2,
init: None,
),
],
global_expressions: [],
functions: [],
entry_points: [
(
name: "main",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("main"),
arguments: [],
result: None,
local_variables: [
(
name: Some("rq"),
ty: 3,
init: None,
),
],
expressions: [
LocalVariable(1),
Literal(U32(4)),
Literal(U32(255)),
Override(1),
Literal(F32(17.0)),
Binary(
op: Multiply,
left: 4,
right: 5,
),
Override(1),
Literal(F32(19.0)),
Binary(
op: Multiply,
left: 7,
right: 8,
),
Override(1),
Literal(F32(23.0)),
Binary(
op: Multiply,
left: 10,
right: 11,
),
Splat(
size: Tri,
value: 12,
),
Override(1),
Literal(F32(29.0)),
Binary(
op: Multiply,
left: 14,
right: 15,
),
Override(1),
Literal(F32(31.0)),
Binary(
op: Multiply,
left: 17,
right: 18,
),
Override(1),
Literal(F32(37.0)),
Binary(
op: Multiply,
left: 20,
right: 21,
),
Compose(
ty: 5,
components: [
16,
19,
22,
],
),
Compose(
ty: 6,
components: [
2,
3,
6,
9,
13,
23,
],
),
GlobalVariable(1),
RayQueryProceedResult,
],
named_expressions: {
24: "desc",
},
body: [
Emit((
start: 5,
end: 6,
)),
Emit((
start: 8,
end: 9,
)),
Emit((
start: 11,
end: 13,
)),
Emit((
start: 15,
end: 16,
)),
Emit((
start: 18,
end: 19,
)),
Emit((
start: 21,
end: 24,
)),
RayQuery(
query: 1,
fun: Initialize(
acceleration_structure: 25,
descriptor: 24,
),
),
Loop(
body: [
RayQuery(
query: 1,
fun: Proceed(
result: 26,
),
),
If(
condition: 26,
accept: [],
reject: [
Break,
],
),
Block([]),
],
continuing: [],
break_if: None,
),
Return(
value: None,
),
],
),
),
],
)

View File

@ -0,0 +1,259 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Float,
width: 4,
)),
),
(
name: None,
inner: AccelerationStructure,
),
(
name: None,
inner: RayQuery,
),
(
name: None,
inner: Scalar((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Vector(
size: Tri,
scalar: (
kind: Float,
width: 4,
),
),
),
(
name: Some("RayDesc"),
inner: Struct(
members: [
(
name: Some("flags"),
ty: 4,
binding: None,
offset: 0,
),
(
name: Some("cull_mask"),
ty: 4,
binding: None,
offset: 4,
),
(
name: Some("tmin"),
ty: 1,
binding: None,
offset: 8,
),
(
name: Some("tmax"),
ty: 1,
binding: None,
offset: 12,
),
(
name: Some("origin"),
ty: 5,
binding: None,
offset: 16,
),
(
name: Some("dir"),
ty: 5,
binding: None,
offset: 32,
),
],
span: 48,
),
),
],
special_types: (
ray_desc: Some(6),
ray_intersection: None,
predeclared_types: {},
),
constants: [],
overrides: [
(
name: Some("o"),
id: None,
ty: 1,
init: None,
),
],
global_variables: [
(
name: Some("acc_struct"),
space: Handle,
binding: Some((
group: 0,
binding: 0,
)),
ty: 2,
init: None,
),
],
global_expressions: [],
functions: [],
entry_points: [
(
name: "main",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("main"),
arguments: [],
result: None,
local_variables: [
(
name: Some("rq"),
ty: 3,
init: None,
),
],
expressions: [
LocalVariable(1),
Literal(U32(4)),
Literal(U32(255)),
Override(1),
Literal(F32(17.0)),
Binary(
op: Multiply,
left: 4,
right: 5,
),
Override(1),
Literal(F32(19.0)),
Binary(
op: Multiply,
left: 7,
right: 8,
),
Override(1),
Literal(F32(23.0)),
Binary(
op: Multiply,
left: 10,
right: 11,
),
Splat(
size: Tri,
value: 12,
),
Override(1),
Literal(F32(29.0)),
Binary(
op: Multiply,
left: 14,
right: 15,
),
Override(1),
Literal(F32(31.0)),
Binary(
op: Multiply,
left: 17,
right: 18,
),
Override(1),
Literal(F32(37.0)),
Binary(
op: Multiply,
left: 20,
right: 21,
),
Compose(
ty: 5,
components: [
16,
19,
22,
],
),
Compose(
ty: 6,
components: [
2,
3,
6,
9,
13,
23,
],
),
GlobalVariable(1),
RayQueryProceedResult,
],
named_expressions: {
24: "desc",
},
body: [
Emit((
start: 5,
end: 6,
)),
Emit((
start: 8,
end: 9,
)),
Emit((
start: 11,
end: 13,
)),
Emit((
start: 15,
end: 16,
)),
Emit((
start: 18,
end: 19,
)),
Emit((
start: 21,
end: 24,
)),
RayQuery(
query: 1,
fun: Initialize(
acceleration_structure: 25,
descriptor: 24,
),
),
Loop(
body: [
RayQuery(
query: 1,
fun: Proceed(
result: 26,
),
),
If(
condition: 26,
accept: [],
reject: [
Break,
],
),
Block([]),
],
continuing: [],
break_if: None,
),
Return(
value: None,
),
],
),
),
],
)

View File

@ -0,0 +1,45 @@
// language: metal2.4
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
struct _RayQuery {
metal::raytracing::intersector<metal::raytracing::instancing, metal::raytracing::triangle_data, metal::raytracing::world_space_data> intersector;
metal::raytracing::intersector<metal::raytracing::instancing, metal::raytracing::triangle_data, metal::raytracing::world_space_data>::result_type intersection;
bool ready = false;
};
constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) {
return ty==metal::raytracing::intersection_type::triangle ? 1 :
ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0;
}
struct RayDesc {
uint flags;
uint cull_mask;
float tmin;
float tmax;
metal::float3 origin;
metal::float3 dir;
};
constant float o = 2.0;
kernel void main_(
metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
) {
_RayQuery rq = {};
RayDesc desc = RayDesc {4u, 255u, 34.0, 38.0, metal::float3(46.0), metal::float3(58.0, 62.0, 74.0)};
rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
rq.intersector.set_opacity_cull_mode((desc.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (desc.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
rq.intersector.force_opacity((desc.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (desc.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq.intersector.accept_any_intersection((desc.flags & 4) != 0);
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(desc.origin, desc.dir, desc.tmin, desc.tmax), acc_struct, desc.cull_mask); rq.ready = true;
while(true) {
bool _e31 = rq.ready;
rq.ready = false;
if (_e31) {
} else {
break;
}
}
return;
}

View File

@ -0,0 +1,77 @@
; SPIR-V
; Version: 1.4
; Generator: rspirv
; Bound: 46
OpCapability Shader
OpCapability RayQueryKHR
OpExtension "SPV_KHR_ray_query"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %13 "main" %10
OpExecutionMode %13 LocalSize 1 1 1
OpMemberDecorate %8 0 Offset 0
OpMemberDecorate %8 1 Offset 4
OpMemberDecorate %8 2 Offset 8
OpMemberDecorate %8 3 Offset 12
OpMemberDecorate %8 4 Offset 16
OpMemberDecorate %8 5 Offset 32
OpDecorate %10 DescriptorSet 0
OpDecorate %10 Binding 0
%2 = OpTypeVoid
%3 = OpTypeFloat 32
%4 = OpTypeAccelerationStructureNV
%5 = OpTypeRayQueryKHR
%6 = OpTypeInt 32 0
%7 = OpTypeVector %3 3
%8 = OpTypeStruct %6 %6 %3 %3 %7 %7
%9 = OpConstant %3 2.0
%11 = OpTypePointer UniformConstant %4
%10 = OpVariable %11 UniformConstant
%14 = OpTypeFunction %2
%16 = OpConstant %6 4
%17 = OpConstant %6 255
%18 = OpConstant %3 34.0
%19 = OpConstant %3 38.0
%20 = OpConstant %3 46.0
%21 = OpConstantComposite %7 %20 %20 %20
%22 = OpConstant %3 58.0
%23 = OpConstant %3 62.0
%24 = OpConstant %3 74.0
%25 = OpConstantComposite %7 %22 %23 %24
%26 = OpConstantComposite %8 %16 %17 %18 %19 %21 %25
%28 = OpTypePointer Function %5
%41 = OpTypeBool
%13 = OpFunction %2 None %14
%12 = OpLabel
%27 = OpVariable %28 Function
%15 = OpLoad %4 %10
OpBranch %29
%29 = OpLabel
%30 = OpCompositeExtract %6 %26 0
%31 = OpCompositeExtract %6 %26 1
%32 = OpCompositeExtract %3 %26 2
%33 = OpCompositeExtract %3 %26 3
%34 = OpCompositeExtract %7 %26 4
%35 = OpCompositeExtract %7 %26 5
OpRayQueryInitializeKHR %27 %15 %30 %31 %34 %32 %35 %33
OpBranch %36
%36 = OpLabel
OpLoopMerge %37 %39 None
OpBranch %38
%38 = OpLabel
%40 = OpRayQueryProceedKHR %41 %27
OpSelectionMerge %42 None
OpBranchConditional %40 %42 %43
%43 = OpLabel
OpBranch %37
%42 = OpLabel
OpBranch %44
%44 = OpLabel
OpBranch %45
%45 = OpLabel
OpBranch %39
%39 = OpLabel
OpBranch %36
%37 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -466,6 +466,7 @@ fn write_output_spv(
);
}
} else {
assert!(pipeline_constants.is_empty());
write_output_spv_inner(input, module, info, &options, None, "spvasm");
}
}
@ -857,6 +858,10 @@ fn convert_wgsl() {
"overrides-atomicCompareExchangeWeak",
Targets::IR | Targets::SPIRV,
),
(
"overrides-ray-query",
Targets::IR | Targets::SPIRV | Targets::METAL,
),
];
for &(name, targets) in inputs.iter() {