[naga wgsl-in] Allow override expressions as local var initializers.

Allow `LocalVariable::init` to be an override expression.

Note that this is unrelated to WGSL compliance. The WGSL front end
already accepts any sort of expression as an initializer for
`LocalVariable`s, but initialization by an override expression was
handled in the same way as initialization by a runtime expression, via
an explicit `Store` statement.

This commit merely lets us skip the `Store` when the initializer is an
override expression, producing slightly cleaner output in some cases.
This commit is contained in:
Jim Blandy 2024-03-26 15:25:57 -07:00 committed by Teodor Tanasoaia
parent 7df0aa6364
commit 2ad95b2774
11 changed files with 76 additions and 96 deletions

View File

@ -304,6 +304,13 @@ fn process_function(
filter_emits_in_block(&mut function.body, &function.expressions);
// Update local expression initializers.
for (_, local) in function.local_variables.iter_mut() {
if let &mut Some(ref mut init) = &mut local.init {
*init = adjusted_local_expressions[init.index()];
}
}
// We've changed the keys of `function.named_expression`, so we have to
// rebuild it from scratch.
let named_expressions = mem::take(&mut function.named_expressions);

View File

@ -1317,7 +1317,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// expression, so its value depends on the
// state at the point of initialization.
if is_inside_loop
|| !ctx.local_expression_kind_tracker.is_const(init)
|| !ctx.local_expression_kind_tracker.is_const_or_override(init)
{
(None, Some(init))
} else {

View File

@ -998,7 +998,7 @@ pub struct LocalVariable {
///
/// This handle refers to this `LocalVariable`'s function's
/// [`expressions`] arena, but it is required to be an evaluated
/// constant expression.
/// override expression.
///
/// [`expressions`]: Function::expressions
pub init: Option<Handle<Expression>>,

View File

@ -54,8 +54,8 @@ pub enum LocalVariableError {
InvalidType(Handle<crate::Type>),
#[error("Initializer doesn't match the variable type")]
InitializerType,
#[error("Initializer is not const")]
NonConstInitializer,
#[error("Initializer is not a const or override expression")]
NonConstOrOverrideInitializer,
}
#[derive(Clone, Debug, thiserror::Error)]
@ -945,8 +945,8 @@ impl super::Validator {
return Err(LocalVariableError::InitializerType);
}
if !local_expr_kind.is_const(init) {
return Err(LocalVariableError::NonConstInitializer);
if !local_expr_kind.is_const_or_override(init) {
return Err(LocalVariableError::NonConstOrOverrideInitializer);
}
}

View File

@ -51,18 +51,6 @@
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(4),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 2,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
@ -83,7 +71,7 @@
),
(
uniformity: (
non_uniform_result: Some(7),
non_uniform_result: Some(6),
requirements: (""),
),
ref_count: 1,
@ -95,7 +83,7 @@
),
(
uniformity: (
non_uniform_result: Some(8),
non_uniform_result: Some(7),
requirements: (""),
),
ref_count: 1,
@ -107,7 +95,7 @@
),
(
uniformity: (
non_uniform_result: Some(8),
non_uniform_result: Some(7),
requirements: (""),
),
ref_count: 1,
@ -128,7 +116,7 @@
),
(
uniformity: (
non_uniform_result: Some(8),
non_uniform_result: Some(7),
requirements: (""),
),
ref_count: 1,
@ -140,7 +128,7 @@
),
(
uniformity: (
non_uniform_result: Some(12),
non_uniform_result: Some(11),
requirements: (""),
),
ref_count: 1,

View File

@ -17,13 +17,12 @@ float gain_x_10_ = 11.0;
void main() {
float t = 0.0;
float t = 23.0;
bool x = false;
float gain_x_100_ = 0.0;
t = 23.0;
x = true;
float _e10 = gain_x_10_;
gain_x_100_ = (_e10 * 10.0);
float _e9 = gain_x_10_;
gain_x_100_ = (_e9 * 10.0);
return;
}

View File

@ -11,13 +11,12 @@ static float gain_x_10_ = 11.0;
[numthreads(1, 1, 1)]
void main()
{
float t = (float)0;
float t = 23.0;
bool x = (bool)0;
float gain_x_100_ = (float)0;
t = 23.0;
x = true;
float _expr10 = gain_x_10_;
gain_x_100_ = (_expr10 * 10.0);
float _expr9 = gain_x_10_;
gain_x_100_ = (_expr9 * 10.0);
return;
}

View File

@ -109,7 +109,7 @@
(
name: Some("t"),
ty: 2,
init: None,
init: Some(3),
),
(
name: Some("x"),
@ -130,56 +130,51 @@
left: 1,
right: 2,
),
LocalVariable(1),
Override(1),
Unary(
op: LogicalNot,
expr: 5,
expr: 4,
),
LocalVariable(2),
GlobalVariable(1),
Load(
pointer: 8,
pointer: 7,
),
Literal(F32(10.0)),
Binary(
op: Multiply,
left: 9,
right: 10,
left: 8,
right: 9,
),
LocalVariable(3),
],
named_expressions: {
6: "a",
5: "a",
},
body: [
Emit((
start: 2,
end: 3,
)),
Emit((
start: 4,
end: 5,
)),
Store(
pointer: 4,
value: 3,
pointer: 6,
value: 5,
),
Emit((
start: 5,
end: 6,
)),
Store(
pointer: 7,
value: 6,
),
Emit((
start: 8,
end: 9,
start: 7,
end: 8,
)),
Emit((
start: 10,
end: 11,
start: 9,
end: 10,
)),
Store(
pointer: 12,
value: 11,
pointer: 11,
value: 10,
),
Return(
value: None,

View File

@ -109,7 +109,7 @@
(
name: Some("t"),
ty: 2,
init: None,
init: Some(3),
),
(
name: Some("x"),
@ -130,56 +130,51 @@
left: 1,
right: 2,
),
LocalVariable(1),
Override(1),
Unary(
op: LogicalNot,
expr: 5,
expr: 4,
),
LocalVariable(2),
GlobalVariable(1),
Load(
pointer: 8,
pointer: 7,
),
Literal(F32(10.0)),
Binary(
op: Multiply,
left: 9,
right: 10,
left: 8,
right: 9,
),
LocalVariable(3),
],
named_expressions: {
6: "a",
5: "a",
},
body: [
Emit((
start: 2,
end: 3,
)),
Emit((
start: 4,
end: 5,
)),
Store(
pointer: 4,
value: 3,
pointer: 6,
value: 5,
),
Emit((
start: 5,
end: 6,
)),
Store(
pointer: 7,
value: 6,
),
Emit((
start: 8,
end: 9,
start: 7,
end: 8,
)),
Emit((
start: 10,
end: 11,
start: 9,
end: 10,
)),
Store(
pointer: 12,
value: 11,
pointer: 11,
value: 10,
),
Return(
value: None,

View File

@ -15,12 +15,11 @@ constant float inferred_f32_ = 2.718;
kernel void main_(
) {
float gain_x_10_ = 11.0;
float t = {};
float t = 23.0;
bool x = {};
float gain_x_100_ = {};
t = 23.0;
x = true;
float _e10 = gain_x_10_;
gain_x_100_ = _e10 * 10.0;
float _e9 = gain_x_10_;
gain_x_100_ = _e9 * 10.0;
return;
}

View File

@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 32
; Bound: 31
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@ -25,21 +25,19 @@ OpExecutionMode %18 LocalSize 1 1 1
%19 = OpTypeFunction %2
%20 = OpConstant %4 23.0
%22 = OpTypePointer Function %4
%23 = OpConstantNull %4
%25 = OpTypePointer Function %3
%26 = OpConstantNull %3
%28 = OpConstantNull %4
%24 = OpTypePointer Function %3
%25 = OpConstantNull %3
%27 = OpConstantNull %4
%18 = OpFunction %2 None %19
%17 = OpLabel
%21 = OpVariable %22 Function %23
%24 = OpVariable %25 Function %26
%27 = OpVariable %22 Function %28
OpBranch %29
%29 = OpLabel
OpStore %21 %20
OpStore %24 %5
%30 = OpLoad %4 %15
%31 = OpFMul %4 %30 %13
OpStore %27 %31
%21 = OpVariable %22 Function %20
%23 = OpVariable %24 Function %25
%26 = OpVariable %22 Function %27
OpBranch %28
%28 = OpLabel
OpStore %23 %5
%29 = OpLoad %4 %15
%30 = OpFMul %4 %29 %13
OpStore %26 %30
OpReturn
OpFunctionEnd