From 67ef37ae991f72f06a58774c3866d716d1c9a9c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Thu, 16 Jun 2022 22:45:21 +0100 Subject: [PATCH] Add support for 'break if' to IR, wgsl-in, and all backends. --- src/back/dot/mod.rs | 4 + src/back/glsl/mod.rs | 13 ++- src/back/hlsl/writer.rs | 17 ++- src/back/msl/writer.rs | 13 ++- src/back/spv/block.rs | 96 ++++++++++++++--- src/back/spv/writer.rs | 7 +- src/back/wgsl/writer.rs | 18 +++- src/front/glsl/parser/functions.rs | 3 + src/front/spv/function.rs | 6 +- src/front/spv/mod.rs | 1 + src/front/wgsl/mod.rs | 123 ++++++++++++++++++---- src/lib.rs | 14 ++- src/valid/analyzer.rs | 1 + src/valid/function.rs | 15 +++ tests/in/break-if.wgsl | 32 ++++++ tests/out/glsl/break-if.main.Compute.glsl | 65 ++++++++++++ tests/out/hlsl/break-if.hlsl | 64 +++++++++++ tests/out/hlsl/break-if.hlsl.config | 3 + tests/out/ir/collatz.ron | 1 + tests/out/ir/shadow.ron | 1 + tests/out/msl/break-if.msl | 69 ++++++++++++ tests/out/spv/break-if.spvasm | 88 ++++++++++++++++ tests/out/wgsl/break-if.wgsl | 47 +++++++++ tests/snapshots.rs | 4 + tests/wgsl-errors.rs | 41 ++++++++ 25 files changed, 698 insertions(+), 48 deletions(-) create mode 100644 tests/in/break-if.wgsl create mode 100644 tests/out/glsl/break-if.main.Compute.glsl create mode 100644 tests/out/hlsl/break-if.hlsl create mode 100644 tests/out/hlsl/break-if.hlsl.config create mode 100644 tests/out/msl/break-if.msl create mode 100644 tests/out/spv/break-if.spvasm create mode 100644 tests/out/wgsl/break-if.wgsl diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 523e48ce5..265e40075 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -81,11 +81,15 @@ impl StatementGraph { S::Loop { ref body, ref continuing, + break_if, } => { let body_id = self.add(body); self.flow.push((id, body_id, "body")); let continuing_id = self.add(continuing); self.flow.push((body_id, continuing_id, "continuing")); + if let Some(expr) = break_if { + self.dependencies.push((id, expr, "break if")); + } "Loop" } S::Return { value } => { diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index a9d7efc8d..a59a0b9d7 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1800,15 +1800,24 @@ impl<'a, W: Write> Writer<'a, W> { Statement::Loop { ref body, ref continuing, + break_if, } => { - if !continuing.is_empty() { + if !continuing.is_empty() || break_if.is_some() { let gate_name = self.namer.call("loop_init"); writeln!(self.out, "{}bool {} = true;", level, gate_name)?; writeln!(self.out, "{}while(true) {{", level)?; let l2 = level.next(); + let l3 = l2.next(); writeln!(self.out, "{}if (!{}) {{", l2, gate_name)?; for sta in continuing { - self.write_stmt(sta, ctx, l2.next())?; + self.write_stmt(sta, ctx, l3)?; + } + if let Some(condition) = break_if { + write!(self.out, "{}if (", l3)?; + self.write_expr(condition, ctx)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", l3.next())?; + writeln!(self.out, "{}}}", l3)?; } writeln!(self.out, "{}}}", l2)?; writeln!(self.out, "{}{} = false;", level.next(), gate_name)?; diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index de8b02936..4a14cc6ba 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1497,18 +1497,27 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Statement::Loop { ref body, ref continuing, + break_if, } => { let l2 = level.next(); - if !continuing.is_empty() { + if !continuing.is_empty() || break_if.is_some() { let gate_name = self.namer.call("loop_init"); writeln!(self.out, "{}bool {} = true;", level, gate_name)?; writeln!(self.out, "{}while(true) {{", level)?; writeln!(self.out, "{}if (!{}) {{", l2, gate_name)?; + let l3 = l2.next(); for sta in continuing.iter() { - self.write_stmt(module, sta, func_ctx, l2.next())?; + self.write_stmt(module, sta, func_ctx, l3)?; } - writeln!(self.out, "{}}}", level.next())?; - writeln!(self.out, "{}{} = false;", level.next(), gate_name)?; + if let Some(condition) = break_if { + write!(self.out, "{}if (", l3)?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", l3.next())?; + writeln!(self.out, "{}}}", l3)?; + } + writeln!(self.out, "{}}}", l2)?; + writeln!(self.out, "{}{} = false;", l2, gate_name)?; } else { writeln!(self.out, "{}while(true) {{", level)?; } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 3992b6374..c35786f7a 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -2552,14 +2552,23 @@ impl Writer { crate::Statement::Loop { ref body, ref continuing, + break_if, } => { - if !continuing.is_empty() { + if !continuing.is_empty() || break_if.is_some() { let gate_name = self.namer.call("loop_init"); writeln!(self.out, "{}bool {} = true;", level, gate_name)?; writeln!(self.out, "{}while(true) {{", level)?; let lif = level.next(); + let lcontinuing = lif.next(); writeln!(self.out, "{}if (!{}) {{", lif, gate_name)?; - self.put_block(lif.next(), continuing, context)?; + self.put_block(lcontinuing, continuing, context)?; + if let Some(condition) = break_if { + write!(self.out, "{}if (", lcontinuing)?; + self.put_expression(condition, &context.expression, true)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", lcontinuing.next())?; + writeln!(self.out, "{}}}", lcontinuing)?; + } writeln!(self.out, "{}}}", lif)?; writeln!(self.out, "{}{} = false;", lif, gate_name)?; } else { diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 413f4df5d..1c475e49b 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -37,6 +37,28 @@ enum ExpressionPointer { }, } +/// The termination statement to be added to the end of the block +pub enum BlockExit { + /// Generates an OpReturn (void return) + Return, + /// Generates an OpBranch to the specified block + Branch { + /// The branch target block + target: Word, + }, + /// Translates a loop `break if` into an `OpBranchConditional` to the + /// merge block if true (the merge block is passed through [`LoopContext::break_id`] + /// or else to the loop header (passed through [`preamble_id`]) + /// + /// [`preamble_id`]: Self::BreakIf::preamble_id + BreakIf { + /// The condition of the `break if` + condition: Handle, + /// The loop header block id + preamble_id: Word, + }, +} + impl Writer { // Flip Y coordinate to adjust for coordinate space difference // between SPIR-V and our IR. @@ -1491,7 +1513,7 @@ impl<'w> BlockContext<'w> { &mut self, label_id: Word, statements: &[crate::Statement], - exit_id: Option, + exit: BlockExit, loop_context: LoopContext, ) -> Result<(), Error> { let mut block = Block::new(label_id); @@ -1508,7 +1530,12 @@ impl<'w> BlockContext<'w> { self.function.consume(block, Instruction::branch(scope_id)); let merge_id = self.gen_id(); - self.write_block(scope_id, block_statements, Some(merge_id), loop_context)?; + self.write_block( + scope_id, + block_statements, + BlockExit::Branch { target: merge_id }, + loop_context, + )?; block = Block::new(merge_id); } @@ -1546,10 +1573,20 @@ impl<'w> BlockContext<'w> { ); if let Some(block_id) = accept_id { - self.write_block(block_id, accept, Some(merge_id), loop_context)?; + self.write_block( + block_id, + accept, + BlockExit::Branch { target: merge_id }, + loop_context, + )?; } if let Some(block_id) = reject_id { - self.write_block(block_id, reject, Some(merge_id), loop_context)?; + self.write_block( + block_id, + reject, + BlockExit::Branch { target: merge_id }, + loop_context, + )?; } block = Block::new(merge_id); @@ -1611,7 +1648,9 @@ impl<'w> BlockContext<'w> { self.write_block( *label_id, &case.body, - Some(case_finish_id), + BlockExit::Branch { + target: case_finish_id, + }, inner_context, )?; } @@ -1619,7 +1658,12 @@ impl<'w> BlockContext<'w> { // If no default was encountered write a empty block to satisfy the presence of // a block the default label if !reached_default { - self.write_block(default_id, &[], Some(merge_id), inner_context)?; + self.write_block( + default_id, + &[], + BlockExit::Branch { target: merge_id }, + inner_context, + )?; } block = Block::new(merge_id); @@ -1627,6 +1671,7 @@ impl<'w> BlockContext<'w> { crate::Statement::Loop { ref body, ref continuing, + break_if, } => { let preamble_id = self.gen_id(); self.function @@ -1649,17 +1694,29 @@ impl<'w> BlockContext<'w> { self.write_block( body_id, body, - Some(continuing_id), + BlockExit::Branch { + target: continuing_id, + }, LoopContext { continuing_id: Some(continuing_id), break_id: Some(merge_id), }, )?; + let exit = match break_if { + Some(condition) => BlockExit::BreakIf { + condition, + preamble_id, + }, + None => BlockExit::Branch { + target: preamble_id, + }, + }; + self.write_block( continuing_id, continuing, - Some(preamble_id), + exit, LoopContext { continuing_id: None, break_id: Some(merge_id), @@ -1955,12 +2012,10 @@ impl<'w> BlockContext<'w> { } } - let termination = match exit_id { - Some(id) => Instruction::branch(id), - // This can happen if the last branch had all the paths - // leading out of the graph (i.e. returning). - // Or it may be the end of the self.function. - None => match self.ir_function.result { + let termination = match exit { + // We're generating code for the top-level Block of the function, so we + // need to end it with some kind of return instruction. + BlockExit::Return => match self.ir_function.result { Some(ref result) if self.function.entry_point_context.is_none() => { let type_id = self.get_type_id(LookupType::Handle(result.ty)); let null_id = self.writer.write_constant_null(type_id); @@ -1968,6 +2023,19 @@ impl<'w> BlockContext<'w> { } _ => Instruction::return_void(), }, + BlockExit::Branch { target } => Instruction::branch(target), + BlockExit::BreakIf { + condition, + preamble_id, + } => { + let condition_id = self.cached[condition]; + + Instruction::branch_conditional( + condition_id, + loop_context.break_id.unwrap(), + preamble_id, + ) + } }; self.function.consume(block, termination); diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index e9dae65f9..6e8758f0c 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -574,7 +574,12 @@ impl Writer { context .function .consume(prelude, Instruction::branch(main_id)); - context.write_block(main_id, &ir_function.body, None, LoopContext::default())?; + context.write_block( + main_id, + &ir_function.body, + super::block::BlockExit::Return, + LoopContext::default(), + )?; // Consume the `BlockContext`, ending its borrows and letting the // `Writer` steal back its cached expression table and temp_list. diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 77ab496cb..8a7a71cdc 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -908,6 +908,7 @@ impl Writer { Statement::Loop { ref body, ref continuing, + break_if, } => { write!(self.out, "{}", level)?; writeln!(self.out, "loop {{")?; @@ -917,11 +918,26 @@ impl Writer { self.write_stmt(module, sta, func_ctx, l2)?; } - if !continuing.is_empty() { + // The continuing is optional so we don't need to write it if + // it is empty, but the `break if` counts as a continuing statement + // so even if `continuing` is empty we must generate it if a + // `break if` exists + if !continuing.is_empty() || break_if.is_some() { writeln!(self.out, "{}continuing {{", l2)?; for sta in continuing.iter() { self.write_stmt(module, sta, func_ctx, l2.next())?; } + + // The `break if` is always the last + // statement of the `continuing` block + if let Some(condition) = break_if { + // The trailing space is important + write!(self.out, "{}break if ", l2.next())?; + self.write_expr(module, condition, func_ctx)?; + // Close the `break if` statement + writeln!(self.out, ";")?; + } + writeln!(self.out, "{}}}", l2)?; } diff --git a/src/front/glsl/parser/functions.rs b/src/front/glsl/parser/functions.rs index 2f271b816..6cb9bf017 100644 --- a/src/front/glsl/parser/functions.rs +++ b/src/front/glsl/parser/functions.rs @@ -358,6 +358,7 @@ impl<'source> ParsingContext<'source> { Statement::Loop { body: loop_body, continuing: Block::new(), + break_if: None, }, meta, ); @@ -411,6 +412,7 @@ impl<'source> ParsingContext<'source> { Statement::Loop { body: loop_body, continuing: Block::new(), + break_if: None, }, meta, ); @@ -513,6 +515,7 @@ impl<'source> ParsingContext<'source> { Statement::Loop { body: block, continuing, + break_if: None, }, meta, ); diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index c57695601..956f93cf9 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -556,7 +556,11 @@ impl<'function> BlockContext<'function> { let continuing = lower_impl(blocks, bodies, continuing); block.push( - crate::Statement::Loop { body, continuing }, + crate::Statement::Loop { + body, + continuing, + break_if: None, + }, crate::Span::default(), ) } diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 234efd016..9f5d710e1 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3565,6 +3565,7 @@ impl> Parser { S::Loop { ref mut body, ref mut continuing, + break_if: _, } => { self.patch_statements(body, expressions, fun_parameter_sampling)?; self.patch_statements(continuing, expressions, fun_parameter_sampling)?; diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index f0e6ac51c..2f68534cf 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -125,6 +125,8 @@ pub enum Error<'a> { BadIncrDecrReferenceType(Span), InvalidResolve(ResolveError), InvalidForInitializer(Span), + /// A break if appeared outside of a continuing block + InvalidBreakIf(Span), InvalidGatherComponent(Span, u32), InvalidConstructorComponentType(Span, i32), InvalidIdentifierUnderscore(Span), @@ -307,6 +309,11 @@ impl<'a> Error<'a> { labels: vec![(bad_span.clone(), "not an assignment or function call".into())], notes: vec![], }, + Error::InvalidBreakIf(ref bad_span) => ParseError { + message: "A break if is only allowed in a continuing block".to_string(), + labels: vec![(bad_span.clone(), "not in a continuing block".into())], + notes: vec![], + }, Error::InvalidGatherComponent(ref bad_span, component) => ParseError { message: format!("textureGather component {} doesn't exist, must be 0, 1, 2, or 3", component), labels: vec![(bad_span.clone(), "invalid component".into())], @@ -3811,26 +3818,7 @@ impl Parser { Some(crate::Statement::Switch { selector, cases }) } - "loop" => { - let _ = lexer.next(); - let mut body = crate::Block::new(); - let mut continuing = crate::Block::new(); - lexer.expect(Token::Paren('{'))?; - - loop { - if lexer.skip(Token::Word("continuing")) { - continuing = self.parse_block(lexer, context.reborrow(), false)?; - lexer.expect(Token::Paren('}'))?; - break; - } - if lexer.skip(Token::Paren('}')) { - break; - } - self.parse_statement(lexer, context.reborrow(), &mut body, false)?; - } - - Some(crate::Statement::Loop { body, continuing }) - } + "loop" => Some(self.parse_loop(lexer, context.reborrow(), &mut emitter)?), "while" => { let _ = lexer.next(); let mut body = crate::Block::new(); @@ -3863,6 +3851,7 @@ impl Parser { Some(crate::Statement::Loop { body, continuing: crate::Block::new(), + break_if: None, }) } "for" => { @@ -3935,10 +3924,22 @@ impl Parser { self.parse_statement(lexer, context.reborrow(), &mut body, false)?; } - Some(crate::Statement::Loop { body, continuing }) + Some(crate::Statement::Loop { + body, + continuing, + break_if: None, + }) } "break" => { - let _ = lexer.next(); + let (_, mut span) = lexer.next(); + // Check if the next token is an `if`, this indicates + // that the user tried to type out a `break if` which + // is illegal in this position. + let (peeked_token, peeked_span) = lexer.peek(); + if let Token::Word("if") = peeked_token { + span.end = peeked_span.end; + return Err(Error::InvalidBreakIf(span)); + } Some(crate::Statement::Break) } "continue" => { @@ -4041,6 +4042,84 @@ impl Parser { Ok(()) } + fn parse_loop<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, '_>, + emitter: &mut super::Emitter, + ) -> Result> { + let _ = lexer.next(); + let mut body = crate::Block::new(); + let mut continuing = crate::Block::new(); + let mut break_if = None; + lexer.expect(Token::Paren('{'))?; + + loop { + if lexer.skip(Token::Word("continuing")) { + // Branch for the `continuing` block, this must be + // the last thing in the loop body + + // Expect a opening brace to start the continuing block + lexer.expect(Token::Paren('{'))?; + loop { + if lexer.skip(Token::Word("break")) { + // Branch for the `break if` statement, this statement + // has the form `break if ;` and must be the last + // statement in a continuing block + + // The break must be followed by an `if` to form + // the break if + lexer.expect(Token::Word("if"))?; + + // Start the emitter to begin parsing an expression + emitter.start(context.expressions); + let condition = self.parse_general_expression( + lexer, + context.as_expression(&mut body, emitter), + )?; + // Add all emits to the continuing body + continuing.extend(emitter.finish(context.expressions)); + // Set the condition of the break if to the newly parsed + // expression + break_if = Some(condition); + + // Expext a semicolon to close the statement + lexer.expect(Token::Separator(';'))?; + // Expect a closing brace to close the continuing block, + // since the break if must be the last statement + lexer.expect(Token::Paren('}'))?; + // Stop parsing the continuing block + break; + } else if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the continuing block and should stop processing + break; + } else { + // Otherwise try to parse a statement + self.parse_statement(lexer, context.reborrow(), &mut continuing, false)?; + } + } + // Since the continuing block must be the last part of the loop body, + // we expect to see a closing brace to end the loop body + lexer.expect(Token::Paren('}'))?; + break; + } + if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the loop body and should stop processing + break; + } + // Otherwise try to parse a statement + self.parse_statement(lexer, context.reborrow(), &mut body, false)?; + } + + Ok(crate::Statement::Loop { + body, + continuing, + break_if, + }) + } + fn parse_block<'a>( &mut self, lexer: &mut Lexer<'a>, diff --git a/src/lib.rs b/src/lib.rs index 60bbe0e28..29bcbeaf4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1439,11 +1439,23 @@ pub enum Statement { /// this loop. (It may have `Break` and `Continue` statements targeting /// loops or switches nested within the `continuing` block.) /// + /// If present, `break_if` is an expression which is evaluated after the + /// continuing block. If its value is true, control continues after the + /// `Loop` statement, rather than branching back to the top of body as + /// usual. The `break_if` expression corresponds to a "break if" statement + /// in WGSL, or a loop whose back edge is an `OpBranchConditional` + /// instruction in SPIR-V. + /// /// [`Break`]: Statement::Break /// [`Continue`]: Statement::Continue /// [`Kill`]: Statement::Kill /// [`Return`]: Statement::Return - Loop { body: Block, continuing: Block }, + /// [`break if`]: Self::Loop::break_if + Loop { + body: Block, + continuing: Block, + break_if: Option>, + }, /// Exits the innermost enclosing [`Loop`] or [`Switch`]. /// diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 932e8b0f8..9a7130ff9 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -841,6 +841,7 @@ impl FunctionInfo { S::Loop { ref body, ref continuing, + break_if: _, } => { let body_uniformity = self.process_block(body, other_functions, disruptor, expression_arena)?; diff --git a/src/valid/function.rs b/src/valid/function.rs index b928ff517..67903601c 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -499,6 +499,7 @@ impl super::Validator { S::Loop { ref body, ref continuing, + break_if, } => { // special handling for block scoping is needed here, // because the continuing{} block inherits the scope @@ -520,6 +521,20 @@ impl super::Validator { &context.with_abilities(ControlFlowAbility::empty()), )? .stages; + + if let Some(condition) = break_if { + match *context.resolve_type(condition, &self.valid_expression_set)? { + Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + } => {} + _ => { + return Err(FunctionError::InvalidIfType(condition) + .with_span_handle(condition, context.expressions)) + } + } + } + for handle in self.valid_expression_list.drain(base_expression_count..) { self.valid_expression_set.remove(handle.index()); } diff --git a/tests/in/break-if.wgsl b/tests/in/break-if.wgsl new file mode 100644 index 000000000..a948edf14 --- /dev/null +++ b/tests/in/break-if.wgsl @@ -0,0 +1,32 @@ +@compute @workgroup_size(1) +fn main() {} + +fn breakIfEmpty() { + loop { + continuing { + break if true; + } + } +} + +fn breakIfEmptyBody(a: bool) { + loop { + continuing { + var b = a; + var c = a != b; + + break if a == c; + } + } +} + +fn breakIf(a: bool) { + loop { + var d = a; + var e = a != d; + + continuing { + break if a == e; + } + } +} diff --git a/tests/out/glsl/break-if.main.Compute.glsl b/tests/out/glsl/break-if.main.Compute.glsl new file mode 100644 index 000000000..025156af7 --- /dev/null +++ b/tests/out/glsl/break-if.main.Compute.glsl @@ -0,0 +1,65 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void breakIfEmpty() { + bool loop_init = true; + while(true) { + if (!loop_init) { + if (true) { + break; + } + } + loop_init = false; + } + return; +} + +void breakIfEmptyBody(bool a) { + bool b = false; + bool c = false; + bool loop_init_1 = true; + while(true) { + if (!loop_init_1) { + b = a; + bool _e2 = b; + c = (a != _e2); + bool _e5 = c; + bool unnamed = (a == _e5); + if (unnamed) { + break; + } + } + loop_init_1 = false; + } + return; +} + +void breakIf(bool a_1) { + bool d = false; + bool e = false; + bool loop_init_2 = true; + while(true) { + if (!loop_init_2) { + bool _e5 = e; + bool unnamed_1 = (a_1 == _e5); + if (unnamed_1) { + break; + } + } + loop_init_2 = false; + d = a_1; + bool _e2 = d; + e = (a_1 != _e2); + } + return; +} + +void main() { + return; +} + diff --git a/tests/out/hlsl/break-if.hlsl b/tests/out/hlsl/break-if.hlsl new file mode 100644 index 000000000..fc91096a9 --- /dev/null +++ b/tests/out/hlsl/break-if.hlsl @@ -0,0 +1,64 @@ + +void breakIfEmpty() +{ + bool loop_init = true; + while(true) { + if (!loop_init) { + if (true) { + break; + } + } + loop_init = false; + } + return; +} + +void breakIfEmptyBody(bool a) +{ + bool b = (bool)0; + bool c = (bool)0; + + bool loop_init_1 = true; + while(true) { + if (!loop_init_1) { + b = a; + bool _expr2 = b; + c = (a != _expr2); + bool _expr5 = c; + bool unnamed = (a == _expr5); + if (unnamed) { + break; + } + } + loop_init_1 = false; + } + return; +} + +void breakIf(bool a_1) +{ + bool d = (bool)0; + bool e = (bool)0; + + bool loop_init_2 = true; + while(true) { + if (!loop_init_2) { + bool _expr5 = e; + bool unnamed_1 = (a_1 == _expr5); + if (unnamed_1) { + break; + } + } + loop_init_2 = false; + d = a_1; + bool _expr2 = d; + e = (a_1 != _expr2); + } + return; +} + +[numthreads(1, 1, 1)] +void main() +{ + return; +} diff --git a/tests/out/hlsl/break-if.hlsl.config b/tests/out/hlsl/break-if.hlsl.config new file mode 100644 index 000000000..246c485cf --- /dev/null +++ b/tests/out/hlsl/break-if.hlsl.config @@ -0,0 +1,3 @@ +vertex=() +fragment=() +compute=(main:cs_5_1 ) diff --git a/tests/out/ir/collatz.ron b/tests/out/ir/collatz.ron index 50860b2cb..2f3d06d13 100644 --- a/tests/out/ir/collatz.ron +++ b/tests/out/ir/collatz.ron @@ -261,6 +261,7 @@ ), ], continuing: [], + break_if: None, ), Emit(( start: 24, diff --git a/tests/out/ir/shadow.ron b/tests/out/ir/shadow.ron index d9a07a96a..11411c89e 100644 --- a/tests/out/ir/shadow.ron +++ b/tests/out/ir/shadow.ron @@ -1320,6 +1320,7 @@ value: 120, ), ], + break_if: None, ), Emit(( start: 120, diff --git a/tests/out/msl/break-if.msl b/tests/out/msl/break-if.msl new file mode 100644 index 000000000..600cca60d --- /dev/null +++ b/tests/out/msl/break-if.msl @@ -0,0 +1,69 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +void breakIfEmpty( +) { + bool loop_init = true; + while(true) { + if (!loop_init) { + if (true) { + break; + } + } + loop_init = false; + } + return; +} + +void breakIfEmptyBody( + bool a +) { + bool b = {}; + bool c = {}; + bool loop_init_1 = true; + while(true) { + if (!loop_init_1) { + b = a; + bool _e2 = b; + c = a != _e2; + bool _e5 = c; + bool unnamed = a == _e5; + if (a == c) { + break; + } + } + loop_init_1 = false; + } + return; +} + +void breakIf( + bool a_1 +) { + bool d = {}; + bool e = {}; + bool loop_init_2 = true; + while(true) { + if (!loop_init_2) { + bool _e5 = e; + bool unnamed_1 = a_1 == _e5; + if (a_1 == e) { + break; + } + } + loop_init_2 = false; + d = a_1; + bool _e2 = d; + e = a_1 != _e2; + } + return; +} + +kernel void main_( +) { + return; +} diff --git a/tests/out/spv/break-if.spvasm b/tests/out/spv/break-if.spvasm new file mode 100644 index 000000000..138fb0b61 --- /dev/null +++ b/tests/out/spv/break-if.spvasm @@ -0,0 +1,88 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 50 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %48 "main" +OpExecutionMode %48 LocalSize 1 1 1 +%2 = OpTypeVoid +%4 = OpTypeBool +%3 = OpConstantTrue %4 +%7 = OpTypeFunction %2 +%14 = OpTypePointer Function %4 +%15 = OpConstantNull %4 +%17 = OpConstantNull %4 +%21 = OpTypeFunction %2 %4 +%32 = OpConstantNull %4 +%34 = OpConstantNull %4 +%6 = OpFunction %2 None %7 +%5 = OpLabel +OpBranch %8 +%8 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %10 %12 None +OpBranch %11 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +OpBranchConditional %3 %10 %9 +%10 = OpLabel +OpReturn +OpFunctionEnd +%20 = OpFunction %2 None %21 +%19 = OpFunctionParameter %4 +%18 = OpLabel +%13 = OpVariable %14 Function %15 +%16 = OpVariable %14 Function %17 +OpBranch %22 +%22 = OpLabel +OpBranch %23 +%23 = OpLabel +OpLoopMerge %24 %26 None +OpBranch %25 +%25 = OpLabel +OpBranch %26 +%26 = OpLabel +OpStore %13 %19 +%27 = OpLoad %4 %13 +%28 = OpLogicalNotEqual %4 %19 %27 +OpStore %16 %28 +%29 = OpLoad %4 %16 +%30 = OpLogicalEqual %4 %19 %29 +OpBranchConditional %30 %24 %23 +%24 = OpLabel +OpReturn +OpFunctionEnd +%37 = OpFunction %2 None %21 +%36 = OpFunctionParameter %4 +%35 = OpLabel +%31 = OpVariable %14 Function %32 +%33 = OpVariable %14 Function %34 +OpBranch %38 +%38 = OpLabel +OpBranch %39 +%39 = OpLabel +OpLoopMerge %40 %42 None +OpBranch %41 +%41 = OpLabel +OpStore %31 %36 +%43 = OpLoad %4 %31 +%44 = OpLogicalNotEqual %4 %36 %43 +OpStore %33 %44 +OpBranch %42 +%42 = OpLabel +%45 = OpLoad %4 %33 +%46 = OpLogicalEqual %4 %36 %45 +OpBranchConditional %46 %40 %39 +%40 = OpLabel +OpReturn +OpFunctionEnd +%48 = OpFunction %2 None %7 +%47 = OpLabel +OpBranch %49 +%49 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/break-if.wgsl b/tests/out/wgsl/break-if.wgsl new file mode 100644 index 000000000..04b232905 --- /dev/null +++ b/tests/out/wgsl/break-if.wgsl @@ -0,0 +1,47 @@ +fn breakIfEmpty() { + loop { + continuing { + break if true; + } + } + return; +} + +fn breakIfEmptyBody(a: bool) { + var b: bool; + var c: bool; + + loop { + continuing { + b = a; + let _e2 = b; + c = (a != _e2); + let _e5 = c; + _ = (a == _e5); + break if (a == _e5); + } + } + return; +} + +fn breakIf(a_1: bool) { + var d: bool; + var e: bool; + + loop { + d = a_1; + let _e2 = d; + e = (a_1 != _e2); + continuing { + let _e5 = e; + _ = (a_1 == _e5); + break if (a_1 == _e5); + } + } + return; +} + +@compute @workgroup_size(1, 1, 1) +fn main() { + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index e36378556..a7f773490 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -531,6 +531,10 @@ fn convert_wgsl() { "binding-arrays", Targets::WGSL | Targets::HLSL | Targets::METAL | Targets::SPIRV, ), + ( + "break-if", + Targets::WGSL | Targets::GLSL | Targets::SPIRV | Targets::HLSL | Targets::METAL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index 564958f81..8ccaf1597 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -1556,3 +1556,44 @@ fn host_shareable_types() { } } } + +#[test] +fn misplaced_break_if() { + check( + " + fn test_misplaced_break_if() { + loop { + break if true; + } + } + ", + r###"error: A break if is only allowed in a continuing block + ┌─ wgsl:4:17 + │ +4 │ break if true; + │ ^^^^^^^^ not in a continuing block + +"###, + ); +} + +#[test] +fn break_if_bad_condition() { + check_validation! { + " + fn test_break_if_bad_condition() { + loop { + continuing { + break if 1; + } + } + } + ": + Err( + naga::valid::ValidationError::Function { + error: naga::valid::FunctionError::InvalidIfType(_), + .. + }, + ) + } +}