ray query: validation, better test

This commit is contained in:
Dzmitry Malyshau 2023-03-17 22:53:13 -07:00
parent 18710fee1e
commit 024c197cc8
5 changed files with 284 additions and 124 deletions

View File

@ -35,6 +35,8 @@ pub enum ExpressionError {
InvalidPointerType(Handle<crate::Expression>),
#[error("Array length of {0:?} can't be done")]
InvalidArrayType(Handle<crate::Expression>),
#[error("Get intersection of {0:?} can't be done")]
InvalidRayQueryType(Handle<crate::Expression>),
#[error("Splatting {0:?} can't be done")]
InvalidSplatType(Handle<crate::Expression>),
#[error("Swizzling {0:?} can't be done")]
@ -1427,7 +1429,26 @@ impl super::Validator {
return Err(ExpressionError::InvalidArrayType(expr));
}
},
E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(),
E::RayQueryProceedResult => ShaderStages::all(),
E::RayQueryGetIntersection {
query,
committed: _,
} => match resolver[query] {
Ti::Pointer {
base,
space: crate::AddressSpace::Function,
} => match resolver.types[base].inner {
Ti::RayQuery => ShaderStages::all(),
ref other => {
log::error!("Intersection result of a pointer to {:?}", other);
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
ref other => {
log::error!("Intersection result of {:?}", other);
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
};
Ok(stages)
}

View File

@ -47,8 +47,6 @@ pub enum AtomicError {
InvalidPointer(Handle<crate::Expression>),
#[error("Operand {0:?} has invalid type.")]
InvalidOperand(Handle<crate::Expression>),
#[error("Result expression {0:?} has already been introduced earlier")]
ResultAlreadyInScope(Handle<crate::Expression>),
#[error("Result type for {0:?} doesn't match the statement")]
ResultTypeMismatch(Handle<crate::Expression>),
}
@ -131,6 +129,14 @@ pub enum FunctionError {
},
#[error("Atomic operation is invalid")]
InvalidAtomic(#[from] AtomicError),
#[error("Ray Query {0:?} is not a local variable")]
InvalidRayQueryExpression(Handle<crate::Expression>),
#[error("Acceleration structure {0:?} is not a matching expression")]
InvalidAccelerationStructure(Handle<crate::Expression>),
#[error("Ray descriptor {0:?} is not a matching expression")]
InvalidRayDescriptor(Handle<crate::Expression>),
#[error("Ray Query {0:?} does not have a matching type")]
InvalidRayQueryType(Handle<crate::Type>),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
@ -169,8 +175,10 @@ struct BlockContext<'a> {
info: &'a FunctionInfo,
expressions: &'a Arena<crate::Expression>,
types: &'a UniqueArena<crate::Type>,
local_vars: &'a Arena<crate::LocalVariable>,
global_vars: &'a Arena<crate::GlobalVariable>,
functions: &'a Arena<crate::Function>,
special_types: &'a crate::SpecialTypes,
prev_infos: &'a [FunctionInfo],
return_type: Option<Handle<crate::Type>>,
}
@ -188,8 +196,10 @@ impl<'a> BlockContext<'a> {
info,
expressions: &fun.expressions,
types: &module.types,
local_vars: &fun.local_variables,
global_vars: &module.global_variables,
functions: &module.functions,
special_types: &module.special_types,
prev_infos,
return_type: fun.result.as_ref().map(|fr| fr.ty),
}
@ -299,6 +309,21 @@ impl super::Validator {
Ok(callee_info.available_stages)
}
#[cfg(feature = "validate")]
fn emit_expression(
&mut self,
handle: Handle<crate::Expression>,
context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> {
if self.valid_expression_set.insert(handle.index()) {
self.valid_expression_list.push(handle);
Ok(())
} else {
Err(FunctionError::ExpressionAlreadyInScope(handle)
.with_span_handle(handle, context.expressions))
}
}
#[cfg(feature = "validate")]
fn validate_atomic(
&mut self,
@ -347,13 +372,7 @@ impl super::Validator {
}
}
if self.valid_expression_set.insert(result.index()) {
self.valid_expression_list.push(result);
} else {
return Err(AtomicError::ResultAlreadyInScope(result)
.with_span_handle(result, context.expressions)
.into_other());
}
self.emit_expression(result, context)?;
match context.expressions[result] {
crate::Expression::AtomicResult { ty, comparison }
if {
@ -401,12 +420,7 @@ impl super::Validator {
match *statement {
S::Emit(ref range) => {
for handle in range.clone() {
if self.valid_expression_set.insert(handle.index()) {
self.valid_expression_list.push(handle);
} else {
return Err(FunctionError::ExpressionAlreadyInScope(handle)
.with_span_handle(handle, context.expressions));
}
self.emit_expression(handle, context)?;
}
}
S::Block(ref block) => {
@ -807,8 +821,55 @@ impl super::Validator {
} => {
self.validate_atomic(pointer, fun, value, result, context)?;
}
S::RayQuery { query: _, fun: _ } => {
//TODO
S::RayQuery { query, ref fun } => {
let query_var = match *context.get_expression(query) {
crate::Expression::LocalVariable(var) => &context.local_vars[var],
ref other => {
log::error!("Unexpected ray query expression {other:?}");
return Err(FunctionError::InvalidRayQueryExpression(query)
.with_span_static(span, "invalid query expression"));
}
};
match context.types[query_var.ty].inner {
Ti::RayQuery => {}
ref other => {
log::error!("Unexpected ray query type {other:?}");
return Err(FunctionError::InvalidRayQueryType(query_var.ty)
.with_span_static(span, "invalid query type"));
}
}
match *fun {
crate::RayQueryFunction::Initialize {
acceleration_structure,
descriptor,
} => {
match *context
.resolve_type(acceleration_structure, &self.valid_expression_set)?
{
Ti::AccelerationStructure => {}
_ => {
return Err(FunctionError::InvalidAccelerationStructure(
acceleration_structure,
)
.with_span_static(span, "invalid acceleration structure"))
}
}
let desc_ty_given =
context.resolve_type(descriptor, &self.valid_expression_set)?;
let desc_ty_expected = context
.special_types
.ray_desc
.map(|handle| &context.types[handle].inner);
if Some(desc_ty_given) != desc_ty_expected {
return Err(FunctionError::InvalidRayDescriptor(descriptor)
.with_span_static(span, "invalid ray descriptor"));
}
}
crate::RayQueryFunction::Proceed { result } => {
self.emit_expression(result, context)?;
}
crate::RayQueryFunction::Terminate => {}
}
}
}
}

View File

@ -45,19 +45,29 @@ struct RayIntersection {
struct Output {
visible: u32,
normal: vec3<f32>,
}
@group(0) @binding(1)
var<storage, read_write> output: Output;
fn get_torus_normal(world_point: vec3<f32>, intersection: RayIntersection) -> vec3<f32> {
let local_point = intersection.world_to_object * vec4<f32>(world_point, 1.0);
let point_on_guiding_line = normalize(local_point.xy) * 2.4;
let world_point_on_guiding_line = intersection.object_to_world * vec4<f32>(point_on_guiding_line, 0.0, 1.0);
return normalize(world_point - world_point_on_guiding_line);
}
@compute @workgroup_size(1)
fn main() {
var rq: ray_query;
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), vec3<f32>(0.0, 1.0, 0.0)));
let dir = vec3<f32>(0.0, 1.0, 0.0);
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), dir));
rayQueryProceed(&rq);
while (rayQueryProceed(&rq)) {}
let intersection = rayQueryGetCommittedIntersection(&rq);
output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE);
output.normal = get_torus_normal(dir * intersection.t, intersection);
}

View File

@ -15,14 +15,8 @@ constexpr metal::uint _map_intersection_type(const metal::raytracing::intersecti
struct Output {
uint visible_;
};
struct RayDesc {
uint flags;
uint cull_mask;
float tmin;
float tmax;
metal::float3 origin;
metal::float3 dir;
char _pad1[12];
metal::float3 normal;
};
struct RayIntersection {
uint kind;
@ -38,21 +32,48 @@ struct RayIntersection {
metal::float4x3 object_to_world;
metal::float4x3 world_to_object;
};
struct RayDesc {
uint flags;
uint cull_mask;
float tmin;
float tmax;
metal::float3 origin;
metal::float3 dir;
};
metal::float3 get_torus_normal(
metal::float3 world_point,
RayIntersection intersection
) {
metal::float3 local_point = intersection.world_to_object * metal::float4(world_point, 1.0);
metal::float2 point_on_guiding_line = metal::normalize(local_point.xy) * 2.4000000953674316;
metal::float3 world_point_on_guiding_line = intersection.object_to_world * metal::float4(point_on_guiding_line, 0.0, 1.0);
return metal::normalize(world_point - world_point_on_guiding_line);
}
kernel void main_(
metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
, device Output& output [[user(fake0)]]
) {
_RayQuery rq = {};
RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), metal::float3(0.0, 1.0, 0.0)};
metal::float3 dir = metal::float3(0.0, 1.0, 0.0);
RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), dir};
rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq.intersector.accept_any_intersection((_e12.flags & 4) != 0);
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true;
while(true) {
bool _e13 = rq.ready;
rq.ready = false;
RayIntersection intersection = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
output.visible_ = static_cast<uint>(intersection.kind == 0u);
if (_e13) {
} else {
break;
}
}
RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
output.visible_ = static_cast<uint>(intersection_1.kind == 0u);
metal::float3 _e25 = get_torus_normal(dir * intersection_1.t, intersection_1);
output.normal = _e25;
return;
}

View File

@ -1,105 +1,152 @@
; SPIR-V
; Version: 1.4
; Generator: rspirv
; Bound: 63
; Bound: 95
OpCapability RayQueryKHR
OpCapability Shader
OpExtension "SPV_KHR_ray_query"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %29 "main" %21 %23
OpExecutionMode %29 LocalSize 1 1 1
OpMemberDecorate %13 0 Offset 0
OpEntryPoint GLCompute %48 "main" %23 %25
OpExecutionMode %48 LocalSize 1 1 1
OpMemberDecorate %15 0 Offset 0
OpMemberDecorate %15 1 Offset 4
OpMemberDecorate %15 2 Offset 8
OpMemberDecorate %15 3 Offset 12
OpMemberDecorate %15 4 Offset 16
OpMemberDecorate %15 5 Offset 32
OpMemberDecorate %20 0 Offset 0
OpMemberDecorate %20 1 Offset 4
OpMemberDecorate %20 2 Offset 8
OpMemberDecorate %20 3 Offset 12
OpMemberDecorate %20 4 Offset 16
OpMemberDecorate %20 5 Offset 20
OpMemberDecorate %20 6 Offset 24
OpMemberDecorate %20 7 Offset 28
OpMemberDecorate %20 8 Offset 36
OpMemberDecorate %20 9 Offset 48
OpMemberDecorate %20 9 ColMajor
OpMemberDecorate %20 9 MatrixStride 16
OpMemberDecorate %20 10 Offset 112
OpMemberDecorate %20 10 ColMajor
OpMemberDecorate %20 10 MatrixStride 16
OpDecorate %21 DescriptorSet 0
OpDecorate %21 Binding 0
OpMemberDecorate %15 1 Offset 16
OpMemberDecorate %19 0 Offset 0
OpMemberDecorate %19 1 Offset 4
OpMemberDecorate %19 2 Offset 8
OpMemberDecorate %19 3 Offset 12
OpMemberDecorate %19 4 Offset 16
OpMemberDecorate %19 5 Offset 20
OpMemberDecorate %19 6 Offset 24
OpMemberDecorate %19 7 Offset 28
OpMemberDecorate %19 8 Offset 36
OpMemberDecorate %19 9 Offset 48
OpMemberDecorate %19 9 ColMajor
OpMemberDecorate %19 9 MatrixStride 16
OpMemberDecorate %19 10 Offset 112
OpMemberDecorate %19 10 ColMajor
OpMemberDecorate %19 10 MatrixStride 16
OpMemberDecorate %21 0 Offset 0
OpMemberDecorate %21 1 Offset 4
OpMemberDecorate %21 2 Offset 8
OpMemberDecorate %21 3 Offset 12
OpMemberDecorate %21 4 Offset 16
OpMemberDecorate %21 5 Offset 32
OpDecorate %23 DescriptorSet 0
OpDecorate %23 Binding 1
OpDecorate %24 Block
OpMemberDecorate %24 0 Offset 0
OpDecorate %23 Binding 0
OpDecorate %25 DescriptorSet 0
OpDecorate %25 Binding 1
OpDecorate %26 Block
OpMemberDecorate %26 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 0
%3 = OpConstant %4 4
%5 = OpConstant %4 255
%7 = OpTypeFloat 32
%6 = OpConstant %7 0.1
%8 = OpConstant %7 100.0
%9 = OpConstant %7 0.0
%10 = OpConstant %7 1.0
%11 = OpConstant %4 0
%12 = OpTypeAccelerationStructureNV
%13 = OpTypeStruct %4
%14 = OpTypeVector %7 3
%15 = OpTypeStruct %4 %4 %7 %7 %14 %14
%16 = OpTypeRayQueryKHR
%17 = OpTypeVector %7 2
%18 = OpTypeBool
%19 = OpTypeMatrix %14 4
%20 = OpTypeStruct %4 %7 %4 %4 %4 %4 %4 %17 %18 %19 %19
%22 = OpTypePointer UniformConstant %12
%21 = OpVariable %22 UniformConstant
%24 = OpTypeStruct %13
%25 = OpTypePointer StorageBuffer %24
%23 = OpVariable %25 StorageBuffer
%27 = OpTypePointer Function %16
%30 = OpTypeFunction %2
%32 = OpTypePointer StorageBuffer %13
%45 = OpConstant %4 1
%58 = OpTypePointer StorageBuffer %4
%29 = OpFunction %2 None %30
%4 = OpTypeFloat 32
%3 = OpConstant %4 1.0
%5 = OpConstant %4 2.4
%6 = OpConstant %4 0.0
%8 = OpTypeInt 32 0
%7 = OpConstant %8 4
%9 = OpConstant %8 255
%10 = OpConstant %4 0.1
%11 = OpConstant %4 100.0
%12 = OpConstant %8 0
%13 = OpTypeAccelerationStructureNV
%14 = OpTypeVector %4 3
%15 = OpTypeStruct %8 %14
%16 = OpTypeVector %4 2
%17 = OpTypeBool
%18 = OpTypeMatrix %14 4
%19 = OpTypeStruct %8 %4 %8 %8 %8 %8 %8 %16 %17 %18 %18
%20 = OpTypeVector %4 4
%21 = OpTypeStruct %8 %8 %4 %4 %14 %14
%22 = OpTypeRayQueryKHR
%24 = OpTypePointer UniformConstant %13
%23 = OpVariable %24 UniformConstant
%26 = OpTypeStruct %15
%27 = OpTypePointer StorageBuffer %26
%25 = OpVariable %27 StorageBuffer
%32 = OpTypeFunction %14 %14 %19
%46 = OpTypePointer Function %22
%49 = OpTypeFunction %2
%51 = OpTypePointer StorageBuffer %15
%72 = OpConstant %8 1
%85 = OpTypePointer StorageBuffer %8
%90 = OpTypePointer StorageBuffer %14
%31 = OpFunction %14 None %32
%29 = OpFunctionParameter %14
%30 = OpFunctionParameter %19
%28 = OpLabel
%26 = OpVariable %27 Function
%31 = OpLoad %12 %21
%33 = OpAccessChain %32 %23 %11
OpBranch %34
%34 = OpLabel
%35 = OpCompositeConstruct %14 %9 %9 %9
%36 = OpCompositeConstruct %14 %9 %10 %9
%37 = OpCompositeConstruct %15 %3 %5 %6 %8 %35 %36
%38 = OpCompositeExtract %4 %37 0
%39 = OpCompositeExtract %4 %37 1
%40 = OpCompositeExtract %7 %37 2
%41 = OpCompositeExtract %7 %37 3
%42 = OpCompositeExtract %14 %37 4
%43 = OpCompositeExtract %14 %37 5
OpRayQueryInitializeKHR %26 %31 %38 %39 %42 %40 %43 %41
%44 = OpRayQueryProceedKHR %18 %26
%46 = OpRayQueryGetIntersectionTypeKHR %4 %26 %45
%47 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %4 %26 %45
%48 = OpRayQueryGetIntersectionInstanceIdKHR %4 %26 %45
%49 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %4 %26 %45
%50 = OpRayQueryGetIntersectionGeometryIndexKHR %4 %26 %45
%51 = OpRayQueryGetIntersectionPrimitiveIndexKHR %4 %26 %45
%52 = OpRayQueryGetIntersectionTKHR %7 %26 %45
%53 = OpRayQueryGetIntersectionBarycentricsKHR %17 %26 %45
%54 = OpRayQueryGetIntersectionFrontFaceKHR %18 %26 %45
%55 = OpRayQueryGetIntersectionObjectToWorldKHR %19 %26 %45
%56 = OpRayQueryGetIntersectionWorldToObjectKHR %19 %26 %45
%57 = OpCompositeConstruct %20 %46 %52 %47 %48 %49 %50 %51 %53 %54 %55 %56
%59 = OpCompositeExtract %4 %57 0
%60 = OpIEqual %18 %59 %11
%61 = OpSelect %4 %60 %45 %11
%62 = OpAccessChain %58 %33 %11
OpStore %62 %61
OpBranch %33
%33 = OpLabel
%34 = OpCompositeExtract %18 %30 10
%35 = OpCompositeConstruct %20 %29 %3
%36 = OpMatrixTimesVector %14 %34 %35
%37 = OpVectorShuffle %16 %36 %36 0 1
%38 = OpExtInst %16 %1 Normalize %37
%39 = OpVectorTimesScalar %16 %38 %5
%40 = OpCompositeExtract %18 %30 9
%41 = OpCompositeConstruct %20 %39 %6 %3
%42 = OpMatrixTimesVector %14 %40 %41
%43 = OpFSub %14 %29 %42
%44 = OpExtInst %14 %1 Normalize %43
OpReturnValue %44
OpFunctionEnd
%48 = OpFunction %2 None %49
%47 = OpLabel
%45 = OpVariable %46 Function
%50 = OpLoad %13 %23
%52 = OpAccessChain %51 %25 %12
OpBranch %53
%53 = OpLabel
%54 = OpCompositeConstruct %14 %6 %3 %6
%55 = OpCompositeConstruct %14 %6 %6 %6
%56 = OpCompositeConstruct %21 %7 %9 %10 %11 %55 %54
%57 = OpCompositeExtract %8 %56 0
%58 = OpCompositeExtract %8 %56 1
%59 = OpCompositeExtract %4 %56 2
%60 = OpCompositeExtract %4 %56 3
%61 = OpCompositeExtract %14 %56 4
%62 = OpCompositeExtract %14 %56 5
OpRayQueryInitializeKHR %45 %50 %57 %58 %61 %59 %62 %60
OpBranch %63
%63 = OpLabel
OpLoopMerge %64 %66 None
OpBranch %65
%65 = OpLabel
%67 = OpRayQueryProceedKHR %17 %45
OpSelectionMerge %68 None
OpBranchConditional %67 %68 %69
%69 = OpLabel
OpBranch %64
%68 = OpLabel
OpBranch %70
%70 = OpLabel
OpBranch %71
%71 = OpLabel
OpBranch %66
%66 = OpLabel
OpBranch %63
%64 = OpLabel
%73 = OpRayQueryGetIntersectionTypeKHR %8 %45 %72
%74 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %8 %45 %72
%75 = OpRayQueryGetIntersectionInstanceIdKHR %8 %45 %72
%76 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %8 %45 %72
%77 = OpRayQueryGetIntersectionGeometryIndexKHR %8 %45 %72
%78 = OpRayQueryGetIntersectionPrimitiveIndexKHR %8 %45 %72
%79 = OpRayQueryGetIntersectionTKHR %4 %45 %72
%80 = OpRayQueryGetIntersectionBarycentricsKHR %16 %45 %72
%81 = OpRayQueryGetIntersectionFrontFaceKHR %17 %45 %72
%82 = OpRayQueryGetIntersectionObjectToWorldKHR %18 %45 %72
%83 = OpRayQueryGetIntersectionWorldToObjectKHR %18 %45 %72
%84 = OpCompositeConstruct %19 %73 %79 %74 %75 %76 %77 %78 %80 %81 %82 %83
%86 = OpCompositeExtract %8 %84 0
%87 = OpIEqual %17 %86 %12
%88 = OpSelect %8 %87 %72 %12
%89 = OpAccessChain %85 %52 %12
OpStore %89 %88
%91 = OpCompositeExtract %4 %84 1
%92 = OpVectorTimesScalar %14 %54 %91
%93 = OpFunctionCall %14 %31 %92 %84
%94 = OpAccessChain %90 %52 %72
OpStore %94 %93
OpReturn
OpFunctionEnd