From 9993f5c9eceb91f9bb6b9c5c6994f35f50e5f378 Mon Sep 17 00:00:00 2001
From: Dzmitry Malyshau <kvark@fastmail.com>
Date: Mon, 16 Dec 2024 12:04:21 -0800
Subject: [PATCH] [v23] Rt constants backport (#6711)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
---
 CHANGELOG.md                        |   8 ++
 Cargo.lock                          |   2 +-
 examples/src/ray_scene/shader.wgsl  | 164 ++++++++++++++++++++++++++++
 naga/Cargo.toml                     |   2 +-
 naga/src/back/msl/writer.rs         |   8 +-
 naga/src/back/spv/block.rs          |   5 +-
 naga/src/back/spv/ray.rs            |  49 ++++++++-
 naga/src/front/wgsl/lower/mod.rs    |  12 +-
 naga/src/front/wgsl/parse/mod.rs    |  62 ++++++++++-
 naga/src/lib.rs                     |  56 ++++++++++
 naga/tests/in/ray-query.wgsl        |  17 ++-
 naga/tests/out/msl/ray-query.msl    |  37 +++++--
 naga/tests/out/spv/ray-query.spvasm |  41 ++++++-
 13 files changed, 431 insertions(+), 32 deletions(-)
 create mode 100644 examples/src/ray_scene/shader.wgsl

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f81b26bf..2e7eccd1f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,6 +40,14 @@ Bottom level categories:
 
 ## Unreleased
 
+## 23.1.0 (2024-12-11)
+
+### New Features
+
+#### Naga
+
+- Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429)
+
 ## 23.0.1 (2024-11-25)
 
 This release includes patches for `wgpu`, `wgpu-core` and `wgpu-hal`. All other crates remain at [23.0.0](https://github.com/gfx-rs/wgpu/releases/tag/v23.0.0).
diff --git a/Cargo.lock b/Cargo.lock
index 97dcaed7f..7a6d7f640 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1875,7 +1875,7 @@ dependencies = [
 
 [[package]]
 name = "naga"
-version = "23.0.0"
+version = "23.1.0"
 dependencies = [
  "arbitrary",
  "arrayvec",
diff --git a/examples/src/ray_scene/shader.wgsl b/examples/src/ray_scene/shader.wgsl
new file mode 100644
index 000000000..4e16bd945
--- /dev/null
+++ b/examples/src/ray_scene/shader.wgsl
@@ -0,0 +1,164 @@
+struct VertexOutput {
+    @builtin(position) position: vec4<f32>,
+    @location(0) tex_coords: vec2<f32>,
+};
+
+@vertex
+fn vs_main(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
+    var result: VertexOutput;
+    let x = i32(vertex_index) / 2;
+    let y = i32(vertex_index) & 1;
+    let tc = vec2<f32>(
+        f32(x) * 2.0,
+        f32(y) * 2.0
+    );
+    result.position = vec4<f32>(
+        tc.x * 2.0 - 1.0,
+        1.0 - tc.y * 2.0,
+        0.0, 1.0
+    );
+    result.tex_coords = tc;
+    return result;
+}
+
+/*
+The contents of the RayQuery struct are roughly as follows
+let RAY_FLAG_NONE = 0x00u;
+let RAY_FLAG_OPAQUE = 0x01u;
+let RAY_FLAG_NO_OPAQUE = 0x02u;
+let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u;
+let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u;
+let RAY_FLAG_CULL_BACK_FACING = 0x10u;
+let RAY_FLAG_CULL_FRONT_FACING = 0x20u;
+let RAY_FLAG_CULL_OPAQUE = 0x40u;
+let RAY_FLAG_CULL_NO_OPAQUE = 0x80u;
+let RAY_FLAG_SKIP_TRIANGLES = 0x100u;
+let RAY_FLAG_SKIP_AABBS = 0x200u;
+
+let RAY_QUERY_INTERSECTION_NONE = 0u;
+let RAY_QUERY_INTERSECTION_TRIANGLE = 1u;
+let RAY_QUERY_INTERSECTION_GENERATED = 2u;
+let RAY_QUERY_INTERSECTION_AABB = 3u;
+
+struct RayDesc {
+    flags: u32,
+    cull_mask: u32,
+    t_min: f32,
+    t_max: f32,
+    origin: vec3<f32>,
+    dir: vec3<f32>,
+}
+
+struct RayIntersection {
+    kind: u32,
+    t: f32,
+    instance_custom_index: u32,
+    instance_id: u32,
+    sbt_record_offset: u32,
+    geometry_index: u32,
+    primitive_index: u32,
+    barycentrics: vec2<f32>,
+    front_face: bool,
+    object_to_world: mat4x3<f32>,
+    world_to_object: mat4x3<f32>,
+}
+*/
+
+struct Uniforms {
+    view_inv: mat4x4<f32>,
+    proj_inv: mat4x4<f32>,
+};
+
+struct Vertex {
+    pos: vec3<f32>,
+    normal: vec3<f32>,
+    uv: vec2<f32>,
+};
+
+
+struct Instance {
+    first_vertex: u32,
+    first_geometry: u32,
+    last_geometry: u32,
+    _pad: u32
+};
+
+struct Material{
+    roughness_exponent: f32,
+    metalness: f32,
+    specularity: f32,
+    albedo: vec3<f32>
+}
+
+struct Geometry {
+    first_index: u32,
+    material: Material,
+};
+
+
+@group(0) @binding(0)
+var<uniform> uniforms: Uniforms;
+
+@group(0) @binding(1)
+var<storage, read> vertices: array<Vertex>;
+
+@group(0) @binding(2)
+var<storage, read> indices: array<u32>;
+
+@group(0) @binding(3)
+var<storage, read> geometries: array<Geometry>;
+
+@group(0) @binding(4)
+var<storage, read> instances: array<Instance>;
+
+@group(0) @binding(5)
+var acc_struct: acceleration_structure;
+
+@fragment
+fn fs_main(vertex: VertexOutput) -> @location(0) vec4<f32> {
+
+    var color =  vec4<f32>(vertex.tex_coords, 0.0, 1.0);
+
+	let d = vertex.tex_coords * 2.0 - 1.0;
+
+	let origin = (uniforms.view_inv * vec4<f32>(0.0,0.0,0.0,1.0)).xyz;
+	let temp = uniforms.proj_inv * vec4<f32>(d.x, d.y, 1.0, 1.0);
+	let direction = (uniforms.view_inv * vec4<f32>(normalize(temp.xyz), 0.0)).xyz;
+
+    var rq: ray_query;
+    rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.1, 200.0, origin, direction));
+    rayQueryProceed(&rq);
+
+    let intersection = rayQueryGetCommittedIntersection(&rq);
+    if (intersection.kind != RAY_QUERY_INTERSECTION_NONE) {
+        let instance = instances[intersection.instance_custom_index];
+        let geometry = geometries[intersection.geometry_index + instance.first_geometry];
+
+        let index_offset = geometry.first_index;
+        let vertex_offset = instance.first_vertex;
+
+        let first_index_index = intersection.primitive_index * 3u + index_offset;
+
+        let v_0 = vertices[vertex_offset+indices[first_index_index+0u]];
+        let v_1 = vertices[vertex_offset+indices[first_index_index+1u]];
+        let v_2 = vertices[vertex_offset+indices[first_index_index+2u]];
+
+        let bary = vec3<f32>(1.0 - intersection.barycentrics.x - intersection.barycentrics.y, intersection.barycentrics);
+
+        let pos = v_0.pos * bary.x + v_1.pos * bary.y + v_2.pos * bary.z;
+        let normal_raw = v_0.normal * bary.x + v_1.normal * bary.y + v_2.normal * bary.z;
+        let uv = v_0.uv * bary.x + v_1.uv * bary.y + v_2.uv * bary.z;
+
+        let normal = normalize(normal_raw);
+
+        let material = geometry.material;
+
+        color = vec4<f32>(material.albedo, 1.0);
+
+        if(intersection.instance_custom_index == 1u){
+            color = vec4<f32>(normal, 1.0);
+        }
+    }
+
+    return color;
+}
diff --git a/naga/Cargo.toml b/naga/Cargo.toml
index 16682e526..17be61b68 100644
--- a/naga/Cargo.toml
+++ b/naga/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "naga"
-version = "23.0.0"
+version = "23.1.0"
 authors = ["gfx-rs developers"]
 edition = "2021"
 description = "Shader translation infrastructure"
diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs
index 83d937d48..4cc363b05 100644
--- a/naga/src/back/msl/writer.rs
+++ b/naga/src/back/msl/writer.rs
@@ -2237,14 +2237,14 @@ impl<W: Write> Writer<W> {
                     write!(self.out, ")")?;
                 }
             }
-            crate::Expression::RayQueryGetIntersection { query, committed } => {
+            crate::Expression::RayQueryGetIntersection {
+                query,
+                committed: _,
+            } => {
                 if context.lang_version < (2, 4) {
                     return Err(Error::UnsupportedRayTracing);
                 }
 
-                if !committed {
-                    unimplemented!()
-                }
                 let ty = context.module.special_types.ray_intersection.unwrap();
                 let type_name = &self.names[&NameKey::Type(ty)];
                 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs
index 65c48d9e7..8330537cf 100644
--- a/naga/src/back/spv/block.rs
+++ b/naga/src/back/spv/block.rs
@@ -1730,10 +1730,7 @@ impl<'w> BlockContext<'w> {
             }
             crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
             crate::Expression::RayQueryGetIntersection { query, committed } => {
-                if !committed {
-                    return Err(Error::FeatureNotImplemented("candidate intersection"));
-                }
-                self.write_ray_query_get_intersection(query, block)
+                self.write_ray_query_get_intersection(query, block, committed)
             }
         };
 
diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs
index c2daf4b3f..7f16f803e 100644
--- a/naga/src/back/spv/ray.rs
+++ b/naga/src/back/spv/ray.rs
@@ -106,23 +106,60 @@ impl<'w> BlockContext<'w> {
         &mut self,
         query: Handle<crate::Expression>,
         block: &mut Block,
+        is_committed: bool,
     ) -> spirv::Word {
         let query_id = self.cached[query];
-        let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32(
-            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
-        ));
+        let intersection_id =
+            self.writer
+                .get_constant_scalar(crate::Literal::U32(if is_committed {
+                    spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
+                } else {
+                    spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
+                } as _));
 
         let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
             NumericType::Scalar(crate::Scalar::U32),
         )));
-        let kind_id = self.gen_id();
+        let raw_kind_id = self.gen_id();
         block.body.push(Instruction::ray_query_get_intersection(
             spirv::Op::RayQueryGetIntersectionTypeKHR,
             flag_type_id,
-            kind_id,
+            raw_kind_id,
             query_id,
             intersection_id,
         ));
+        let kind_id = if is_committed {
+            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
+            raw_kind_id
+        } else {
+            // Remap from the candidate kind to IR
+            let condition_id = self.gen_id();
+            let committed_triangle_kind_id = self.writer.get_constant_scalar(crate::Literal::U32(
+                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
+                    as _,
+            ));
+            block.body.push(Instruction::binary(
+                spirv::Op::IEqual,
+                self.writer.get_bool_type_id(),
+                condition_id,
+                raw_kind_id,
+                committed_triangle_kind_id,
+            ));
+            let kind_id = self.gen_id();
+            block.body.push(Instruction::select(
+                flag_type_id,
+                kind_id,
+                condition_id,
+                self.writer.get_constant_scalar(crate::Literal::U32(
+                    crate::RayQueryIntersection::Triangle as _,
+                )),
+                self.writer.get_constant_scalar(crate::Literal::U32(
+                    crate::RayQueryIntersection::Aabb as _,
+                )),
+            ));
+            kind_id
+        };
+
         let instance_custom_index_id = self.gen_id();
         block.body.push(Instruction::ray_query_get_intersection(
             spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
@@ -201,6 +238,8 @@ impl<'w> BlockContext<'w> {
             query_id,
             intersection_id,
         ));
+        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
+        // but it's not a property of an intersection.
 
         let transform_type_id =
             self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix {
diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs
index 78e81350b..86bdc170e 100644
--- a/naga/src/front/wgsl/lower/mod.rs
+++ b/naga/src/front/wgsl/lower/mod.rs
@@ -2563,12 +2563,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
                             args.finish()?;
 
                             let _ = ctx.module.generate_ray_intersection_type();
-
                             crate::Expression::RayQueryGetIntersection {
                                 query,
                                 committed: true,
                             }
                         }
+                        "rayQueryGetCandidateIntersection" => {
+                            let mut args = ctx.prepare_args(arguments, 1, span);
+                            let query = self.ray_query_pointer(args.next()?, ctx)?;
+                            args.finish()?;
+
+                            let _ = ctx.module.generate_ray_intersection_type();
+                            crate::Expression::RayQueryGetIntersection {
+                                query,
+                                committed: false,
+                            }
+                        }
                         "RayDesc" => {
                             let ty = ctx.module.generate_ray_desc_type();
                             let handle = self.construct(
diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs
index fcfcc3775..f380f406a 100644
--- a/naga/src/front/wgsl/parse/mod.rs
+++ b/naga/src/front/wgsl/parse/mod.rs
@@ -651,6 +651,14 @@ impl Parser {
         ctx: &mut ExpressionContext<'a, '_, '_>,
     ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
         self.push_rule_span(Rule::PrimaryExpr, lexer);
+        const fn literal_ray_flag<'b>(flag: crate::RayFlag) -> ast::Expression<'b> {
+            ast::Expression::Literal(ast::Literal::Number(Number::U32(flag.bits())))
+        }
+        const fn literal_ray_intersection<'b>(
+            intersection: crate::RayQueryIntersection,
+        ) -> ast::Expression<'b> {
+            ast::Expression::Literal(ast::Literal::Number(Number::U32(intersection as u32)))
+        }
 
         let expr = match lexer.peek() {
             (Token::Paren('('), _) => {
@@ -683,15 +691,63 @@ impl Parser {
             }
             (Token::Word("RAY_FLAG_NONE"), _) => {
                 let _ = lexer.next();
-                ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+                literal_ray_flag(crate::RayFlag::empty())
+            }
+            (Token::Word("RAY_FLAG_FORCE_OPAQUE"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::FORCE_OPAQUE)
+            }
+            (Token::Word("RAY_FLAG_FORCE_NO_OPAQUE"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::FORCE_NO_OPAQUE)
             }
             (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
                 let _ = lexer.next();
-                ast::Expression::Literal(ast::Literal::Number(Number::U32(4)))
+                literal_ray_flag(crate::RayFlag::TERMINATE_ON_FIRST_HIT)
+            }
+            (Token::Word("RAY_FLAG_SKIP_CLOSEST_HIT_SHADER"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::SKIP_CLOSEST_HIT_SHADER)
+            }
+            (Token::Word("RAY_FLAG_CULL_BACK_FACING"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::CULL_BACK_FACING)
+            }
+            (Token::Word("RAY_FLAG_CULL_FRONT_FACING"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::CULL_FRONT_FACING)
+            }
+            (Token::Word("RAY_FLAG_CULL_OPAQUE"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::CULL_OPAQUE)
+            }
+            (Token::Word("RAY_FLAG_CULL_NO_OPAQUE"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::CULL_NO_OPAQUE)
+            }
+            (Token::Word("RAY_FLAG_SKIP_TRIANGLES"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::SKIP_TRIANGLES)
+            }
+            (Token::Word("RAY_FLAG_SKIP_AABBS"), _) => {
+                let _ = lexer.next();
+                literal_ray_flag(crate::RayFlag::SKIP_AABBS)
             }
             (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
                 let _ = lexer.next();
-                ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+                literal_ray_intersection(crate::RayQueryIntersection::None)
+            }
+            (Token::Word("RAY_QUERY_INTERSECTION_TRIANGLE"), _) => {
+                let _ = lexer.next();
+                literal_ray_intersection(crate::RayQueryIntersection::Triangle)
+            }
+            (Token::Word("RAY_QUERY_INTERSECTION_GENERATED"), _) => {
+                let _ = lexer.next();
+                literal_ray_intersection(crate::RayQueryIntersection::Generated)
+            }
+            (Token::Word("RAY_QUERY_INTERSECTION_AABB"), _) => {
+                let _ = lexer.next();
+                literal_ray_intersection(crate::RayQueryIntersection::Aabb)
             }
             (Token::Word(word), span) => {
                 let start = lexer.start_byte_offset();
diff --git a/naga/src/lib.rs b/naga/src/lib.rs
index 145b95f66..3450c0d20 100644
--- a/naga/src/lib.rs
+++ b/naga/src/lib.rs
@@ -2271,3 +2271,59 @@ pub struct Module {
     /// Entry points.
     pub entry_points: Vec<EntryPoint>,
 }
+
+bitflags::bitflags! {
+    /// Ray flags used when casting rays.
+    /// Matching vulkan constants can be found in
+    /// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/ray_common/ray_flags_section.txt
+    #[cfg_attr(feature = "serialize", derive(Serialize))]
+    #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+    #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+    #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+    pub struct RayFlag: u32 {
+        /// Force all intersections to be treated as opaque.
+        const FORCE_OPAQUE = 0x1;
+        /// Force all intersections to be treated as non-opaque.
+        const FORCE_NO_OPAQUE = 0x2;
+        /// Stop traversal after the first hit.
+        const TERMINATE_ON_FIRST_HIT = 0x4;
+        /// Don't execute the closest hit shader.
+        const SKIP_CLOSEST_HIT_SHADER = 0x8;
+        /// Cull back facing geometry.
+        const CULL_BACK_FACING = 0x10;
+        /// Cull front facing geometry.
+        const CULL_FRONT_FACING = 0x20;
+        /// Cull opaque geometry.
+        const CULL_OPAQUE = 0x40;
+        /// Cull non-opaque geometry.
+        const CULL_NO_OPAQUE = 0x80;
+        /// Skip triangular geometry.
+        const SKIP_TRIANGLES = 0x100;
+        /// Skip axis-aligned bounding boxes.
+        const SKIP_AABBS = 0x200;
+    }
+}
+
+/// Type of a ray query intersection.
+/// Matching vulkan constants can be found in
+/// <https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_ray_query.asciidoc>
+/// but the actual values are different for candidate intersections.
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+pub enum RayQueryIntersection {
+    /// No intersection found.
+    /// Matches `RayQueryCommittedIntersectionNoneKHR`.
+    #[default]
+    None = 0,
+    /// Intersecting with triangles.
+    /// Matches `RayQueryCommittedIntersectionTriangleKHR` and `RayQueryCandidateIntersectionTriangleKHR`.
+    Triangle = 1,
+    /// Intersecting with generated primitives.
+    /// Matches `RayQueryCommittedIntersectionGeneratedKHR`.
+    Generated = 2,
+    /// Intersecting with Axis Aligned Bounding Boxes.
+    /// Matches `RayQueryCandidateIntersectionAABBKHR`.
+    Aabb = 3,
+}
diff --git a/naga/tests/in/ray-query.wgsl b/naga/tests/in/ray-query.wgsl
index 0af8c7c95..9f94356b8 100644
--- a/naga/tests/in/ray-query.wgsl
+++ b/naga/tests/in/ray-query.wgsl
@@ -1,7 +1,7 @@
 /*
 let RAY_FLAG_NONE = 0x00u;
-let RAY_FLAG_OPAQUE = 0x01u;
-let RAY_FLAG_NO_OPAQUE = 0x02u;
+let RAY_FLAG_FORCE_OPAQUE = 0x01u;
+let RAY_FLAG_FORCE_NO_OPAQUE = 0x02u;
 let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u;
 let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u;
 let RAY_FLAG_CULL_BACK_FACING = 0x10u;
@@ -14,7 +14,7 @@ let RAY_FLAG_SKIP_AABBS = 0x200u;
 let RAY_QUERY_INTERSECTION_NONE = 0u;
 let RAY_QUERY_INTERSECTION_TRIANGLE = 1u;
 let RAY_QUERY_INTERSECTION_GENERATED = 2u;
-let RAY_QUERY_INTERSECTION_AABB = 4u;
+let RAY_QUERY_INTERSECTION_AABB = 3u;
 
 struct RayDesc {
     flags: u32,
@@ -78,3 +78,14 @@ fn main() {
     output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE);
     output.normal = get_torus_normal(dir * intersection.t, intersection);
 }
+
+@compute @workgroup_size(1)
+fn main_candidate() {
+    let pos = vec3<f32>(0.0);
+    let dir = vec3<f32>(0.0, 1.0, 0.0);
+
+    var rq: ray_query;
+    rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, pos, dir));
+    let intersection = rayQueryGetCandidateIntersection(&rq);
+    output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_AABB);
+}
diff --git a/naga/tests/out/msl/ray-query.msl b/naga/tests/out/msl/ray-query.msl
index 129ad108a..b8d5bb340 100644
--- a/naga/tests/out/msl/ray-query.msl
+++ b/naga/tests/out/msl/ray-query.msl
@@ -46,23 +46,23 @@ RayIntersection query_loop(
     metal::float3 dir,
     metal::raytracing::instance_acceleration_structure acs
 ) {
-    _RayQuery rq = {};
+    _RayQuery rq_1 = {};
     RayDesc _e8 = RayDesc {4u, 255u, 0.1, 100.0, pos, dir};
-    rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
-    rq.intersector.set_opacity_cull_mode((_e8.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e8.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
-    rq.intersector.force_opacity((_e8.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e8.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
-    rq.intersector.accept_any_intersection((_e8.flags & 4) != 0);
-    rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e8.origin, _e8.dir, _e8.tmin, _e8.tmax), acs, _e8.cull_mask);    rq.ready = true;
+    rq_1.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
+    rq_1.intersector.set_opacity_cull_mode((_e8.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e8.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
+    rq_1.intersector.force_opacity((_e8.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e8.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
+    rq_1.intersector.accept_any_intersection((_e8.flags & 4) != 0);
+    rq_1.intersection = rq_1.intersector.intersect(metal::raytracing::ray(_e8.origin, _e8.dir, _e8.tmin, _e8.tmax), acs, _e8.cull_mask);    rq_1.ready = true;
 #define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
     LOOP_IS_REACHABLE while(true) {
-        bool _e9 = rq.ready;
-        rq.ready = false;
+        bool _e9 = rq_1.ready;
+        rq_1.ready = false;
         if (_e9) {
         } else {
             break;
         }
     }
-    return RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
+    return RayIntersection {_map_intersection_type(rq_1.intersection.type), rq_1.intersection.distance, rq_1.intersection.user_instance_id, rq_1.intersection.instance_id, {}, rq_1.intersection.geometry_id, rq_1.intersection.primitive_id, rq_1.intersection.triangle_barycentric_coord, rq_1.intersection.triangle_front_facing, {}, rq_1.intersection.object_to_world_transform, rq_1.intersection.world_to_object_transform};
 }
 
 metal::float3 get_torus_normal(
@@ -87,3 +87,22 @@ kernel void main_(
     output.normal = _e18;
     return;
 }
+
+
+kernel void main_candidate(
+  metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
+, device Output& output [[user(fake0)]]
+) {
+    _RayQuery rq = {};
+    metal::float3 pos_2 = metal::float3(0.0);
+    metal::float3 dir_2 = metal::float3(0.0, 1.0, 0.0);
+    RayDesc _e12 = RayDesc {4u, 255u, 0.1, 100.0, pos_2, dir_2};
+    rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
+    rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
+    rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
+    rq.intersector.accept_any_intersection((_e12.flags & 4) != 0);
+    rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask);    rq.ready = true;
+    RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
+    output.visible = static_cast<uint>(intersection_1.kind == 3u);
+    return;
+}
diff --git a/naga/tests/out/spv/ray-query.spvasm b/naga/tests/out/spv/ray-query.spvasm
index 8b784f2fa..5279bfc2e 100644
--- a/naga/tests/out/spv/ray-query.spvasm
+++ b/naga/tests/out/spv/ray-query.spvasm
@@ -1,14 +1,16 @@
 ; SPIR-V
 ; Version: 1.4
 ; Generator: rspirv
-; Bound: 104
+; Bound: 136
 OpCapability Shader
 OpCapability RayQueryKHR
 OpExtension "SPV_KHR_ray_query"
 %1 = OpExtInstImport "GLSL.std.450"
 OpMemoryModel Logical GLSL450
 OpEntryPoint GLCompute %84 "main" %15 %17
+OpEntryPoint GLCompute %105 "main_candidate" %15 %17
 OpExecutionMode %84 LocalSize 1 1 1
+OpExecutionMode %105 LocalSize 1 1 1
 OpMemberDecorate %10 0 Offset 0
 OpMemberDecorate %10 1 Offset 4
 OpMemberDecorate %10 2 Offset 8
@@ -74,6 +76,8 @@ OpMemberDecorate %18 0 Offset 0
 %91 = OpConstantComposite  %4  %70 %68 %70
 %94 = OpTypePointer StorageBuffer %6
 %99 = OpTypePointer StorageBuffer %4
+%108 = OpConstantComposite  %12  %27 %28 %29 %30 %90 %91
+%109 = OpConstant  %6  3
 %25 = OpFunction  %10  None %26
 %21 = OpFunctionParameter  %4
 %22 = OpFunctionParameter  %4
@@ -161,4 +165,39 @@ OpStore %98 %97
 %103 = OpAccessChain  %99  %89 %50
 OpStore %103 %102
 OpReturn
+OpFunctionEnd
+%105 = OpFunction  %2  None %85
+%104 = OpLabel
+%110 = OpVariable  %32  Function
+%106 = OpLoad  %5  %15
+%107 = OpAccessChain  %87  %17 %88
+OpBranch %111
+%111 = OpLabel
+%112 = OpCompositeExtract  %6  %108 0
+%113 = OpCompositeExtract  %6  %108 1
+%114 = OpCompositeExtract  %3  %108 2
+%115 = OpCompositeExtract  %3  %108 3
+%116 = OpCompositeExtract  %4  %108 4
+%117 = OpCompositeExtract  %4  %108 5
+OpRayQueryInitializeKHR %110 %106 %112 %113 %116 %114 %117 %115
+%118 = OpRayQueryGetIntersectionTypeKHR  %6  %110 %88
+%119 = OpIEqual  %8  %118 %88
+%120 = OpSelect  %6  %119 %50 %109
+%121 = OpRayQueryGetIntersectionInstanceCustomIndexKHR  %6  %110 %88
+%122 = OpRayQueryGetIntersectionInstanceIdKHR  %6  %110 %88
+%123 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR  %6  %110 %88
+%124 = OpRayQueryGetIntersectionGeometryIndexKHR  %6  %110 %88
+%125 = OpRayQueryGetIntersectionPrimitiveIndexKHR  %6  %110 %88
+%126 = OpRayQueryGetIntersectionTKHR  %3  %110 %88
+%127 = OpRayQueryGetIntersectionBarycentricsKHR  %7  %110 %88
+%128 = OpRayQueryGetIntersectionFrontFaceKHR  %8  %110 %88
+%129 = OpRayQueryGetIntersectionObjectToWorldKHR  %9  %110 %88
+%130 = OpRayQueryGetIntersectionWorldToObjectKHR  %9  %110 %88
+%131 = OpCompositeConstruct  %10  %120 %126 %121 %122 %123 %124 %125 %127 %128 %129 %130
+%132 = OpCompositeExtract  %6  %131 0
+%133 = OpIEqual  %8  %132 %109
+%134 = OpSelect  %6  %133 %50 %88
+%135 = OpAccessChain  %94  %107 %88
+OpStore %135 %134
+OpReturn
 OpFunctionEnd
\ No newline at end of file