From e7c7017d2e364938188a1a5876a2b1b4d0dc6e27 Mon Sep 17 00:00:00 2001 From: Patryk Wychowaniec Date: Fri, 5 Jan 2024 19:06:55 +0100 Subject: [PATCH] [naga wgsl-in] Fix parsing `break if`s Closes https://github.com/gfx-rs/wgpu/issues/4982. --- naga/src/front/wgsl/lower/mod.rs | 4 +- naga/tests/in/break-if.wgsl | 12 ++ .../tests/out/glsl/break-if.main.Compute.glsl | 17 ++ naga/tests/out/hlsl/break-if.hlsl | 19 +++ naga/tests/out/msl/break-if.msl | 18 +++ naga/tests/out/spv/break-if.spvasm | 146 +++++++++++------- naga/tests/out/wgsl/break-if.wgsl | 14 ++ 7 files changed, 169 insertions(+), 61 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index c38c57267..ba9b49e13 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1328,7 +1328,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut emitter = Emitter::default(); emitter.start(&ctx.function.expressions); let break_if = break_if - .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) + .map(|expr| { + self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter)) + }) .transpose()?; continuing.extend(emitter.finish(&ctx.function.expressions)); diff --git a/naga/tests/in/break-if.wgsl b/naga/tests/in/break-if.wgsl index a948edf14..27536edb3 100644 --- a/naga/tests/in/break-if.wgsl +++ b/naga/tests/in/break-if.wgsl @@ -30,3 +30,15 @@ fn breakIf(a: bool) { } } } + +fn breakIfSeparateVariable() { + var counter = 0u; + + loop { + counter += 1u; + + continuing { + break if counter == 5u; + } + } +} diff --git a/naga/tests/out/glsl/break-if.main.Compute.glsl b/naga/tests/out/glsl/break-if.main.Compute.glsl index b4554d5d3..ba96de487 100644 --- a/naga/tests/out/glsl/break-if.main.Compute.glsl +++ b/naga/tests/out/glsl/break-if.main.Compute.glsl @@ -57,6 +57,23 @@ void breakIf(bool a_1) { return; } +void breakIfSeparateVariable() { + uint counter = 0u; + bool loop_init_3 = true; + while(true) { + if (!loop_init_3) { + uint _e5 = counter; + if ((_e5 == 5u)) { + break; + } + } + loop_init_3 = false; + uint _e3 = counter; + counter = (_e3 + 1u); + } + return; +} + void main() { return; } diff --git a/naga/tests/out/hlsl/break-if.hlsl b/naga/tests/out/hlsl/break-if.hlsl index f1aaaac09..56b7b48a2 100644 --- a/naga/tests/out/hlsl/break-if.hlsl +++ b/naga/tests/out/hlsl/break-if.hlsl @@ -54,6 +54,25 @@ void breakIf(bool a_1) return; } +void breakIfSeparateVariable() +{ + uint counter = 0u; + + bool loop_init_3 = true; + while(true) { + if (!loop_init_3) { + uint _expr5 = counter; + if ((_expr5 == 5u)) { + break; + } + } + loop_init_3 = false; + uint _expr3 = counter; + counter = (_expr3 + 1u); + } + return; +} + [numthreads(1, 1, 1)] void main() { diff --git a/naga/tests/out/msl/break-if.msl b/naga/tests/out/msl/break-if.msl index 657fdf9f7..8c0d9343b 100644 --- a/naga/tests/out/msl/break-if.msl +++ b/naga/tests/out/msl/break-if.msl @@ -61,6 +61,24 @@ void breakIf( return; } +void breakIfSeparateVariable( +) { + uint counter = 0u; + bool loop_init_3 = true; + while(true) { + if (!loop_init_3) { + uint _e5 = counter; + if (counter == 5u) { + break; + } + } + loop_init_3 = false; + uint _e3 = counter; + counter = _e3 + 1u; + } + return; +} + kernel void main_( ) { return; diff --git a/naga/tests/out/spv/break-if.spvasm b/naga/tests/out/spv/break-if.spvasm index ea944130e..8c8c3edf4 100644 --- a/naga/tests/out/spv/break-if.spvasm +++ b/naga/tests/out/spv/break-if.spvasm @@ -1,88 +1,114 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 50 +; Bound: 67 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %48 "main" -OpExecutionMode %48 LocalSize 1 1 1 +OpEntryPoint GLCompute %65 "main" +OpExecutionMode %65 LocalSize 1 1 1 %2 = OpTypeVoid %3 = OpTypeBool -%6 = OpTypeFunction %2 -%7 = OpConstantTrue %3 -%16 = OpTypeFunction %2 %3 -%18 = OpTypePointer Function %3 -%19 = OpConstantNull %3 -%21 = OpConstantNull %3 -%35 = OpConstantNull %3 -%37 = OpConstantNull %3 -%5 = OpFunction %2 None %6 -%4 = OpLabel -OpBranch %8 -%8 = OpLabel +%4 = OpTypeInt 32 0 +%7 = OpTypeFunction %2 +%8 = OpConstantTrue %3 +%17 = OpTypeFunction %2 %3 +%19 = OpTypePointer Function %3 +%20 = OpConstantNull %3 +%22 = OpConstantNull %3 +%36 = OpConstantNull %3 +%38 = OpConstantNull %3 +%50 = OpConstant %4 0 +%51 = OpConstant %4 1 +%52 = OpConstant %4 5 +%54 = OpTypePointer Function %4 +%6 = OpFunction %2 None %7 +%5 = OpLabel OpBranch %9 %9 = OpLabel -OpLoopMerge %10 %12 None -OpBranch %11 -%11 = OpLabel +OpBranch %10 +%10 = OpLabel +OpLoopMerge %11 %13 None OpBranch %12 %12 = OpLabel -OpBranchConditional %7 %10 %9 -%10 = OpLabel +OpBranch %13 +%13 = OpLabel +OpBranchConditional %8 %11 %10 +%11 = OpLabel OpReturn OpFunctionEnd -%15 = OpFunction %2 None %16 -%14 = OpFunctionParameter %3 -%13 = OpLabel -%17 = OpVariable %18 Function %19 -%20 = OpVariable %18 Function %21 -OpBranch %22 -%22 = OpLabel +%16 = OpFunction %2 None %17 +%15 = OpFunctionParameter %3 +%14 = OpLabel +%18 = OpVariable %19 Function %20 +%21 = OpVariable %19 Function %22 OpBranch %23 %23 = OpLabel -OpLoopMerge %24 %26 None -OpBranch %25 -%25 = OpLabel +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %27 None OpBranch %26 %26 = OpLabel -OpStore %17 %14 -%27 = OpLoad %3 %17 -%28 = OpLogicalNotEqual %3 %14 %27 -OpStore %20 %28 -%29 = OpLoad %3 %20 -%30 = OpLogicalEqual %3 %14 %29 -OpBranchConditional %30 %24 %23 -%24 = OpLabel +OpBranch %27 +%27 = OpLabel +OpStore %18 %15 +%28 = OpLoad %3 %18 +%29 = OpLogicalNotEqual %3 %15 %28 +OpStore %21 %29 +%30 = OpLoad %3 %21 +%31 = OpLogicalEqual %3 %15 %30 +OpBranchConditional %31 %25 %24 +%25 = OpLabel OpReturn OpFunctionEnd -%33 = OpFunction %2 None %16 -%32 = OpFunctionParameter %3 -%31 = OpLabel -%34 = OpVariable %18 Function %35 -%36 = OpVariable %18 Function %37 -OpBranch %38 -%38 = OpLabel +%34 = OpFunction %2 None %17 +%33 = OpFunctionParameter %3 +%32 = OpLabel +%35 = OpVariable %19 Function %36 +%37 = OpVariable %19 Function %38 OpBranch %39 %39 = OpLabel -OpLoopMerge %40 %42 None -OpBranch %41 -%41 = OpLabel -OpStore %34 %32 -%43 = OpLoad %3 %34 -%44 = OpLogicalNotEqual %3 %32 %43 -OpStore %36 %44 +OpBranch %40 +%40 = OpLabel +OpLoopMerge %41 %43 None OpBranch %42 %42 = OpLabel -%45 = OpLoad %3 %36 -%46 = OpLogicalEqual %3 %32 %45 -OpBranchConditional %46 %40 %39 -%40 = OpLabel +OpStore %35 %33 +%44 = OpLoad %3 %35 +%45 = OpLogicalNotEqual %3 %33 %44 +OpStore %37 %45 +OpBranch %43 +%43 = OpLabel +%46 = OpLoad %3 %37 +%47 = OpLogicalEqual %3 %33 %46 +OpBranchConditional %47 %41 %40 +%41 = OpLabel OpReturn OpFunctionEnd -%48 = OpFunction %2 None %6 -%47 = OpLabel -OpBranch %49 -%49 = OpLabel +%49 = OpFunction %2 None %7 +%48 = OpLabel +%53 = OpVariable %54 Function %50 +OpBranch %55 +%55 = OpLabel +OpBranch %56 +%56 = OpLabel +OpLoopMerge %57 %59 None +OpBranch %58 +%58 = OpLabel +%60 = OpLoad %4 %53 +%61 = OpIAdd %4 %60 %51 +OpStore %53 %61 +OpBranch %59 +%59 = OpLabel +%62 = OpLoad %4 %53 +%63 = OpIEqual %3 %62 %52 +OpBranchConditional %63 %57 %56 +%57 = OpLabel +OpReturn +OpFunctionEnd +%65 = OpFunction %2 None %7 +%64 = OpLabel +OpBranch %66 +%66 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/break-if.wgsl b/naga/tests/out/wgsl/break-if.wgsl index 6e65c5215..c3d45a50a 100644 --- a/naga/tests/out/wgsl/break-if.wgsl +++ b/naga/tests/out/wgsl/break-if.wgsl @@ -39,6 +39,20 @@ fn breakIf(a_1: bool) { return; } +fn breakIfSeparateVariable() { + var counter: u32 = 0u; + + loop { + let _e3 = counter; + counter = (_e3 + 1u); + continuing { + let _e5 = counter; + break if (_e5 == 5u); + } + } + return; +} + @compute @workgroup_size(1, 1, 1) fn main() { return;