Make default a switch case (#1477)

* Make default a switch case

Previously the default case of a switch statement was encoded as a block
in the statement but the wgsl spec defines it in such a way that the
default case ordering matters.

* [spv-out] Support for the new switch IR

* [dot-out] Use different labels for default cases
This commit is contained in:
João Capucho 2021-10-26 18:31:54 +01:00 committed by GitHub
parent 5cf11ab734
commit 63dbd38edc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 229 additions and 260 deletions

View File

@ -63,15 +63,16 @@ impl StatementGraph {
S::Switch {
selector,
ref cases,
ref default,
} => {
self.dependencies.push((id, selector, "selector"));
for case in cases {
let case_id = self.add(&case.body);
self.flow.push((id, case_id, "case"));
let label = match case.value {
crate::SwitchValue::Integer(_) => "case",
crate::SwitchValue::Default => "default",
};
self.flow.push((id, case_id, label));
}
let default_id = self.add(default);
self.flow.push((id, default_id, "default"));
"Switch"
}
S::Loop {

View File

@ -1468,7 +1468,6 @@ impl<'a, W: Write> Writer<'a, W> {
Statement::Switch {
selector,
ref cases,
ref default,
} => {
// Start the switch
write!(self.out, "{}", level)?;
@ -1486,7 +1485,12 @@ impl<'a, W: Write> Writer<'a, W> {
// Write all cases
let l2 = level.next();
for case in cases {
writeln!(self.out, "{}case {}{}:", l2, case.value, type_postfix)?;
match case.value {
crate::SwitchValue::Integer(value) => {
writeln!(self.out, "{}case {}{}:", l2, value, type_postfix)?
}
crate::SwitchValue::Default => writeln!(self.out, "{}default:", l2)?,
}
for sta in case.body.iter() {
self.write_stmt(sta, ctx, l2.next())?;
@ -1497,27 +1501,11 @@ impl<'a, W: Write> Writer<'a, W> {
// broken out of at the end of its body.
if case.fall_through {
writeln!(self.out, "{}/* fallthrough */", l2.next())?;
} else if !matches!(
case.body.last(),
Some(&Statement::Break)
| Some(&Statement::Continue)
| Some(&Statement::Return { .. })
| Some(&Statement::Kill)
) {
} else if case.body.last().map_or(true, |s| !s.is_terminator()) {
writeln!(self.out, "{}break;", l2.next())?;
}
}
// Only write the default block if the block isn't empty
// Writing default without a block is valid but it's more readable this way
if !default.is_empty() {
writeln!(self.out, "{}default:", level.next())?;
for sta in default {
self.write_stmt(sta, ctx, l2.next())?;
}
}
writeln!(self.out, "{}}}", level)?
}
// Loops in naga IR are based on wgsl loops, glsl can emulate the behaviour by using a

View File

@ -1358,7 +1358,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Statement::Switch {
selector,
ref cases,
ref default,
} => {
// Start the switch
write!(self.out, "{}", level)?;
@ -1378,11 +1377,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
let indent_level_2 = indent_level_1.next();
for case in cases {
writeln!(
self.out,
"{}case {}{}: {{",
indent_level_1, case.value, type_postfix
)?;
match case.value {
crate::SwitchValue::Integer(value) => writeln!(
self.out,
"{}case {}{}: {{",
indent_level_1, value, type_postfix
)?,
crate::SwitchValue::Default => {
writeln!(self.out, "{}default: {{", indent_level_1)?
}
}
if case.fall_through {
// Generate each fallthrough case statement in a new block. This is done to
@ -1401,25 +1405,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if case.fall_through {
writeln!(self.out, "{}}}", indent_level_2)?;
} else {
} else if case.body.last().map_or(true, |s| !s.is_terminator()) {
writeln!(self.out, "{}break;", indent_level_2)?;
}
writeln!(self.out, "{}}}", indent_level_1)?;
}
// Only write the default block if the block isn't empty
// Writing default without a block is valid but it's more readable this way
if !default.is_empty() {
writeln!(self.out, "{}default: {{", indent_level_1)?;
for sta in default {
self.write_stmt(module, sta, func_ctx, indent_level_2)?;
}
writeln!(self.out, "{}}}", indent_level_1)?;
}
writeln!(self.out, "{}}}", level)?
}
}

View File

@ -311,3 +311,18 @@ impl crate::TypeInner {
}
}
}
impl crate::Statement {
/// Returns true if the statement directly terminates the current block
///
/// Used to decided wether case blocks require a explicit `break`
pub fn is_terminator(&self) -> bool {
match *self {
crate::Statement::Break
| crate::Statement::Continue
| crate::Statement::Return { .. }
| crate::Statement::Kill => true,
_ => false,
}
}
}

View File

@ -1476,7 +1476,6 @@ impl<W: Write> Writer<W> {
crate::Statement::Switch {
selector,
ref cases,
ref default,
} => {
write!(self.out, "{}switch(", level)?;
self.put_expression(selector, &context.expression, true)?;
@ -1490,16 +1489,22 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ") {{")?;
let lcase = level.next();
for case in cases.iter() {
writeln!(self.out, "{}case {}{}: {{", lcase, case.value, type_postfix)?;
match case.value {
crate::SwitchValue::Integer(value) => {
writeln!(self.out, "{}case {}{}: {{", lcase, value, type_postfix)?;
}
crate::SwitchValue::Default => {
writeln!(self.out, "{}default: {{", lcase)?;
}
}
self.put_block(lcase.next(), &case.body, context)?;
if !case.fall_through {
if !case.fall_through
&& case.body.last().map_or(true, |s| !s.is_terminator())
{
writeln!(self.out, "{}break;", lcase.next())?;
}
writeln!(self.out, "{}}}", lcase)?;
}
writeln!(self.out, "{}default: {{", lcase)?;
self.put_block(lcase.next(), default, context)?;
writeln!(self.out, "{}}}", lcase)?;
writeln!(self.out, "{}}}", level)?;
}
crate::Statement::Loop {

View File

@ -1151,7 +1151,6 @@ impl<'w> BlockContext<'w> {
crate::Statement::Switch {
selector,
ref cases,
ref default,
} => {
let selector_id = self.cached[selector];
@ -1162,13 +1161,30 @@ impl<'w> BlockContext<'w> {
));
let default_id = self.gen_id();
let raw_cases = cases
.iter()
.map(|c| super::instructions::Case {
value: c.value as Word,
label_id: self.gen_id(),
})
.collect::<Vec<_>>();
let mut reached_default = false;
let mut raw_cases = Vec::with_capacity(cases.len());
let mut case_ids = Vec::with_capacity(cases.len());
for case in cases.iter() {
match case.value {
crate::SwitchValue::Integer(value) => {
let label_id = self.gen_id();
// No cases should be added after the default case is encountered
// since the default case catches all
if !reached_default {
raw_cases.push(super::instructions::Case {
value: value as Word,
label_id,
});
}
case_ids.push(label_id);
}
crate::SwitchValue::Default => {
case_ids.push(default_id);
reached_default = true;
}
}
}
self.function.consume(
block,
@ -1180,24 +1196,25 @@ impl<'w> BlockContext<'w> {
..loop_context
};
for (i, (case, raw_case)) in cases.iter().zip(raw_cases.iter()).enumerate() {
for (i, (case, label_id)) in cases.iter().zip(case_ids.iter()).enumerate() {
let case_finish_id = if case.fall_through {
match raw_cases.get(i + 1) {
Some(rc) => rc.label_id,
None => default_id,
}
case_ids[i + 1]
} else {
merge_id
};
self.write_block(
raw_case.label_id,
*label_id,
&case.body,
Some(case_finish_id),
inner_context,
)?;
}
self.write_block(default_id, default, Some(merge_id), inner_context)?;
// If no default was encountered write a empty block to satisfy the presence of
// a block the default label
if !reached_default {
self.write_block(default_id, &[], Some(merge_id), inner_context)?;
}
block = Block::new(merge_id);
}

View File

@ -877,7 +877,6 @@ impl<W: Write> Writer<W> {
Statement::Switch {
selector,
ref cases,
ref default,
} => {
// Start the switch
write!(self.out, "{}", level)?;
@ -885,11 +884,6 @@ impl<W: Write> Writer<W> {
self.write_expr(module, selector, func_ctx)?;
writeln!(self.out, ") {{")?;
// Write all cases
let mut write_case = true;
let all_fall_through = cases
.iter()
.all(|case| case.fall_through && case.body.is_empty());
let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
@ -901,16 +895,13 @@ impl<W: Write> Writer<W> {
let l2 = level.next();
if !cases.is_empty() {
for case in cases {
if write_case {
write!(self.out, "{}case ", l2)?;
}
if !all_fall_through && case.fall_through && case.body.is_empty() {
write_case = false;
write!(self.out, "{}{}, ", case.value, type_postfix)?;
continue;
} else {
write_case = true;
writeln!(self.out, "{}{}: {{", case.value, type_postfix)?;
match case.value {
crate::SwitchValue::Integer(value) => {
writeln!(self.out, "{}case {}{}: {{", l2, value, type_postfix)?;
}
crate::SwitchValue::Default => {
writeln!(self.out, "{}default: {{", l2)?;
}
}
for sta in case.body.iter() {
@ -925,16 +916,6 @@ impl<W: Write> Writer<W> {
}
}
if !default.is_empty() {
writeln!(self.out, "{}default: {{", l2)?;
for sta in default {
self.write_stmt(module, sta, func_ctx, l2.next())?;
}
writeln!(self.out, "{}}}", l2)?
}
writeln!(self.out, "{}}}", level)?
}
Statement::Loop {

View File

@ -169,11 +169,10 @@ impl<'source> ParsingContext<'source> {
ctx.emit_start();
let mut cases = Vec::new();
let mut default = Block::new();
self.expect(parser, TokenValue::LeftBrace)?;
loop {
match self.expect_peek(parser)?.value {
let value = match self.expect_peek(parser)?.value {
TokenValue::Case => {
self.bump(parser)?;
let value = {
@ -204,73 +203,11 @@ impl<'source> ParsingContext<'source> {
}
}
};
self.expect(parser, TokenValue::Colon)?;
let mut body = Block::new();
let mut case_terminator = None;
loop {
match self.expect_peek(parser)?.value {
TokenValue::Case
| TokenValue::Default
| TokenValue::RightBrace => break,
_ => {
self.parse_statement(
parser,
ctx,
&mut body,
&mut case_terminator,
)?;
}
}
}
let mut fall_through = true;
if let Some(mut idx) = case_terminator {
if let Statement::Break = body[idx - 1] {
fall_through = false;
idx -= 1;
}
body.cull(idx..)
}
cases.push(SwitchCase {
value,
body,
fall_through,
})
crate::SwitchValue::Integer(value)
}
TokenValue::Default => {
let Token { meta, .. } = self.bump(parser)?;
self.expect(parser, TokenValue::Colon)?;
if !default.is_empty() {
parser.errors.push(Error {
kind: ErrorKind::SemanticError(
"Can only have one default case per switch statement"
.into(),
),
meta,
});
}
let mut default_terminator = None;
loop {
match self.expect_peek(parser)?.value {
TokenValue::Case | TokenValue::RightBrace => break,
_ => {
self.parse_statement(
parser,
ctx,
&mut default,
&mut default_terminator,
)?;
}
}
}
self.bump(parser)?;
crate::SwitchValue::Default
}
TokenValue::RightBrace => {
end_meta = self.bump(parser)?.meta;
@ -290,19 +227,45 @@ impl<'source> ParsingContext<'source> {
meta,
});
}
};
self.expect(parser, TokenValue::Colon)?;
let mut body = Block::new();
let mut case_terminator = None;
loop {
match self.expect_peek(parser)?.value {
TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => {
break
}
_ => {
self.parse_statement(parser, ctx, &mut body, &mut case_terminator)?;
}
}
}
let mut fall_through = true;
if let Some(mut idx) = case_terminator {
if let Statement::Break = body[idx - 1] {
fall_through = false;
idx -= 1;
}
body.cull(idx..)
}
cases.push(SwitchCase {
value,
body,
fall_through,
})
}
meta.subsume(end_meta);
body.push(
Statement::Switch {
selector,
cases,
default,
},
meta,
);
body.push(Statement::Switch { selector, cases }, meta);
meta
}

View File

@ -566,33 +566,31 @@ impl<'function> BlockContext<'function> {
ref cases,
default,
} => {
let default = lower_impl(blocks, bodies, default);
let mut ir_cases: Vec<_> = cases
.iter()
.map(|&(value, body_idx)| {
let body = lower_impl(blocks, bodies, body_idx);
// Handle simple cases that would make a fallthrough statement unreachable code
let fall_through = body.last().map_or(true, |s| !s.is_terminator());
crate::SwitchCase {
value: crate::SwitchValue::Integer(value),
body,
fall_through,
}
})
.collect();
ir_cases.push(crate::SwitchCase {
value: crate::SwitchValue::Default,
body: lower_impl(blocks, bodies, default),
fall_through: false,
});
block.push(
crate::Statement::Switch {
selector,
cases: cases
.iter()
.map(|&(value, body_idx)| {
let body = lower_impl(blocks, bodies, body_idx);
// Handle simple cases that would make a fallthrough statement unreachable code
let fall_through = match body.last() {
Some(&crate::Statement::Break)
| Some(&crate::Statement::Continue)
| Some(&crate::Statement::Kill)
| Some(&crate::Statement::Return { .. }) => false,
_ => true,
};
crate::SwitchCase {
value,
body,
fall_through,
}
})
.collect(),
default,
cases: ir_cases,
},
crate::Span::default(),
)

View File

@ -3225,12 +3225,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
S::Switch {
selector: _,
ref mut cases,
ref mut default,
} => {
for case in cases.iter_mut() {
self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?;
}
self.patch_statements(default, expressions, fun_parameter_sampling)?;
}
S::Loop {
ref mut body,

View File

@ -3200,6 +3200,29 @@ impl Parser {
Ok(())
}
fn parse_switch_case_body<'a, 'out>(
&mut self,
lexer: &mut Lexer<'a>,
mut context: StatementContext<'a, '_, 'out>,
) -> Result<(bool, crate::Block), Error<'a>> {
let mut body = crate::Block::new();
lexer.expect(Token::Paren('{'))?;
let fall_through = loop {
// default statements
if lexer.skip(Token::Word("fallthrough")) {
lexer.expect(Token::Separator(';'))?;
lexer.expect(Token::Paren('}'))?;
break true;
}
if lexer.skip(Token::Paren('}')) {
break false;
}
self.parse_statement(lexer, context.reborrow(), &mut body, false)?;
};
Ok((fall_through, body))
}
fn parse_statement<'a, 'out>(
&mut self,
lexer: &mut Lexer<'a>,
@ -3495,7 +3518,6 @@ impl Parser {
block.extend(emitter.finish(context.expressions));
lexer.expect(Token::Paren('{'))?;
let mut cases = Vec::new();
let mut default = crate::Block::new();
loop {
// cases + default
@ -3503,7 +3525,6 @@ impl Parser {
(Token::Word("case"), _) => {
// parse a list of values
let value = loop {
// TODO: Switch statements also allow for floats, bools and unsigned integers. See https://www.w3.org/TR/WGSL/#switch-statement
let value = Self::parse_switch_value(lexer, uint)?;
if lexer.skip(Token::Separator(',')) {
if lexer.skip(Token::Separator(':')) {
@ -3514,41 +3535,30 @@ impl Parser {
break value;
}
cases.push(crate::SwitchCase {
value,
value: crate::SwitchValue::Integer(value),
body: crate::Block::new(),
fall_through: true,
});
};
let mut body = crate::Block::new();
lexer.expect(Token::Paren('{'))?;
let fall_through = loop {
// default statements
if lexer.skip(Token::Word("fallthrough")) {
lexer.expect(Token::Separator(';'))?;
lexer.expect(Token::Paren('}'))?;
break true;
}
if lexer.skip(Token::Paren('}')) {
break false;
}
self.parse_statement(
lexer,
context.reborrow(),
&mut body,
false,
)?;
};
let (fall_through, body) =
self.parse_switch_case_body(lexer, context.reborrow())?;
cases.push(crate::SwitchCase {
value,
value: crate::SwitchValue::Integer(value),
body,
fall_through,
});
}
(Token::Word("default"), _) => {
lexer.expect(Token::Separator(':'))?;
default = self.parse_block(lexer, context.reborrow(), false)?;
let (fall_through, body) =
self.parse_switch_case_body(lexer, context.reborrow())?;
cases.push(crate::SwitchCase {
value: crate::SwitchValue::Default,
body,
fall_through,
});
}
(Token::Paren('}'), _) => break,
other => {
@ -3557,11 +3567,7 @@ impl Parser {
}
}
Some(crate::Statement::Switch {
selector,
cases,
default,
})
Some(crate::Statement::Switch { selector, cases })
}
"loop" => {
let _ = lexer.next();

View File

@ -1301,6 +1301,16 @@ pub enum Expression {
pub use block::Block;
/// The value of the switch case
// Clone is used only for error reporting and is not intended for end users
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum SwitchValue {
Integer(i32),
Default,
}
/// A case for a switch statement.
// Clone is used only for error reporting and is not intended for end users
#[derive(Clone, Debug)]
@ -1308,7 +1318,7 @@ pub use block::Block;
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct SwitchCase {
/// Value, upon which the case is considered true.
pub value: i32,
pub value: SwitchValue,
/// Body of the case.
pub body: Block,
/// If true, the control flow continues to the next case in the list,
@ -1341,7 +1351,6 @@ pub enum Statement {
Switch {
selector: Handle<Expression>, //int
cases: Vec<SwitchCase>,
default: Block,
},
/// Executes a block repeatedly.

View File

@ -21,14 +21,12 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
Some(&mut S::Switch {
selector: _,
ref mut cases,
ref mut default,
}) => {
for case in cases.iter_mut() {
if !case.fall_through {
ensure_block_returns(&mut case.body);
}
}
ensure_block_returns(default);
}
Some(&mut S::Emit(_))
| Some(&mut S::Break)

View File

@ -694,7 +694,6 @@ impl FunctionInfo {
S::Switch {
selector,
ref cases,
ref default,
} => {
let selector_nur = self.add_ref(selector);
let branch_disruptor =
@ -715,14 +714,7 @@ impl FunctionInfo {
};
uniformity = uniformity | case_uniformity;
}
// using the disruptor inherited from the last fall-through chain
let default_exit = self.process_block(
default,
other_functions,
case_disruptor,
expression_arena,
)?;
uniformity | default_exit
uniformity
}
S::Loop {
ref body,

View File

@ -106,6 +106,8 @@ pub enum FunctionError {
InvalidSwitchType(Handle<crate::Expression>),
#[error("Multiple `switch` cases for {0:?} are present")]
ConflictingSwitchCase(i32),
#[error("Multiple `default` cases are present")]
MultipleDefaultCases,
#[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
InvalidStorePointer(Handle<crate::Expression>),
#[error("The value {0:?} can not be stored")]
@ -421,7 +423,6 @@ impl super::Validator {
S::Switch {
selector,
ref cases,
ref default,
} => {
match *context.resolve_type(selector, &self.valid_expression_set)? {
Ti::Scalar {
@ -438,16 +439,34 @@ impl super::Validator {
}
}
self.select_cases.clear();
let mut default = false;
for case in cases {
if !self.select_cases.insert(case.value) {
return Err(FunctionError::ConflictingSwitchCase(case.value)
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"conflicting switch arm here",
));
match case.value {
crate::SwitchValue::Integer(value) => {
if !self.select_cases.insert(value) {
return Err(FunctionError::ConflictingSwitchCase(value)
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"conflicting switch arm here",
));
}
}
crate::SwitchValue::Default => {
if default {
return Err(FunctionError::MultipleDefaultCases
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"duplicated switch arm here",
));
}
default = true
}
}
}
let pass_through_abilities = context.abilities
@ -457,7 +476,6 @@ impl super::Validator {
for case in cases {
stages &= self.validate_block(&case.body, &sub_context)?.stages;
}
stages &= self.validate_block(default, &sub_context)?.stages;
}
S::Loop {
ref body,

View File

@ -39,6 +39,7 @@ void main() {
switch(1) {
default:
pos = 1;
break;
}
int _e4 = pos;
switch(_e4) {
@ -55,6 +56,7 @@ void main() {
break;
default:
pos = 3;
break;
}
switch(0u) {
case 0u:

View File

@ -13,7 +13,6 @@ void switch_case_break()
switch(0) {
case 0: {
break;
break;
}
}
return;
@ -25,7 +24,6 @@ void loop_switch_continue(int x)
switch(x) {
case 1: {
continue;
break;
}
}
}
@ -42,6 +40,7 @@ void main(uint3 global_id : SV_DispatchThreadID)
switch(1) {
default: {
pos = 1;
break;
}
}
int _expr4 = pos;
@ -49,7 +48,6 @@ void main(uint3 global_id : SV_DispatchThreadID)
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
@ -66,6 +64,7 @@ void main(uint3 global_id : SV_DispatchThreadID)
}
default: {
pos = 3;
break;
}
}
switch(0u) {
@ -78,12 +77,10 @@ void main(uint3 global_id : SV_DispatchThreadID)
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
return;
break;
}
case 3: {
/* fallthrough */
@ -93,7 +90,6 @@ void main(uint3 global_id : SV_DispatchThreadID)
}
case 4: {
return;
break;
}
default: {
pos = 3;

View File

@ -18,9 +18,6 @@ void switch_case_break(
switch(0) {
case 0: {
break;
break;
}
default: {
}
}
return;
@ -33,9 +30,6 @@ void loop_switch_continue(
switch(x) {
case 1: {
continue;
break;
}
default: {
}
}
}
@ -53,6 +47,7 @@ kernel void main1(
switch(1) {
default: {
pos = 1;
break;
}
}
int _e4 = pos;
@ -60,7 +55,6 @@ kernel void main1(
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
@ -74,33 +68,29 @@ kernel void main1(
}
default: {
pos = 3;
break;
}
}
switch(0u) {
case 0u: {
break;
}
default: {
}
}
int _e10 = pos;
switch(_e10) {
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
return;
break;
}
case 3: {
pos = 2;
}
case 4: {
return;
break;
}
default: {
pos = 3;