evaluate override-expressions in functions

This commit is contained in:
teoxoy 2024-03-06 12:22:33 +01:00 committed by Teodor Tanasoaia
parent fd5c4db606
commit 3abdfde0ba
8 changed files with 472 additions and 14 deletions

View File

@ -1,10 +1,11 @@
use super::PipelineConstants;
use crate::{
proc::{ConstantEvaluator, ConstantEvaluatorError},
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan,
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
Span, Statement, SwitchCase, TypeInner, WithSpan,
};
use std::{borrow::Cow, collections::HashSet};
use std::{borrow::Cow, collections::HashSet, mem};
use thiserror::Error;
#[derive(Error, Debug, Clone)]
@ -175,6 +176,18 @@ pub(super) fn process_overrides<'a>(
}
}
let mut functions = mem::take(&mut module.functions);
for (_, function) in functions.iter_mut() {
process_function(&mut module, &override_map, function)?;
}
let _ = mem::replace(&mut module.functions, functions);
let mut entry_points = mem::take(&mut module.entry_points);
for ep in entry_points.iter_mut() {
process_function(&mut module, &override_map, &mut ep.function)?;
}
let _ = mem::replace(&mut module.entry_points, entry_points);
// Now that the global expression arena has changed, we need to
// recompute those expressions' types. For the time being, do a
// full re-validation.
@ -237,6 +250,64 @@ fn process_override(
Ok(h)
}
/// Replaces all `Expression::Override`s in this function's expression arena
/// with `Expression::Constant` and evaluates all expressions in its arena.
fn process_function(
module: &mut Module,
override_map: &[Handle<Constant>],
function: &mut Function,
) -> Result<(), ConstantEvaluatorError> {
// A map from original local expression handles to
// handles in the new, local expression arena.
let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len());
let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
let mut expressions = mem::take(&mut function.expressions);
// Dummy `emitter` and `block` for the constant evaluator.
// We can ignore the concept of emitting expressions here since
// expressions have already been covered by a `Statement::Emit`
// in the frontend.
// The only thing we might have to do is remove some expressions
// that have been covered by a `Statement::Emit`. See the docs of
// `filter_emits_in_block` for the reasoning.
let mut emitter = Emitter::default();
let mut block = Block::new();
for (old_h, expr, span) in expressions.drain() {
let mut expr = match expr {
Expression::Override(h) => Expression::Constant(override_map[h.index()]),
expr => expr,
};
let mut evaluator = ConstantEvaluator::for_wgsl_function(
module,
&mut function.expressions,
&mut local_expression_kind_tracker,
&mut emitter,
&mut block,
);
adjust_expr(&adjusted_local_expressions, &mut expr);
let h = evaluator.try_eval_and_append(expr, span)?;
debug_assert_eq!(old_h.index(), adjusted_local_expressions.len());
adjusted_local_expressions.push(h);
}
adjust_block(&adjusted_local_expressions, &mut function.body);
let new_body = filter_emits_in_block(&function.body, &function.expressions);
let _ = mem::replace(&mut function.body, new_body);
let named_expressions = mem::take(&mut function.named_expressions);
for (expr_h, name) in named_expressions {
function
.named_expressions
.insert(adjusted_local_expressions[expr_h.index()], name);
}
Ok(())
}
/// Replace every expression handle in `expr` with its counterpart
/// given by `new_pos`.
fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
@ -409,6 +480,207 @@ fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
}
}
/// Replace every expression handle in `block` with its counterpart
/// given by `new_pos`.
fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) {
for stmt in block.iter_mut() {
adjust_stmt(new_pos, stmt);
}
}
/// Replace every expression handle in `stmt` with its counterpart
/// given by `new_pos`.
fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
let adjust = |expr: &mut Handle<Expression>| {
*expr = new_pos[expr.index()];
};
match *stmt {
Statement::Emit(ref mut range) => {
if let Some((mut first, mut last)) = range.first_and_last() {
adjust(&mut first);
adjust(&mut last);
*range = Range::new_from_bounds(first, last);
}
}
Statement::Block(ref mut block) => {
adjust_block(new_pos, block);
}
Statement::If {
ref mut condition,
ref mut accept,
ref mut reject,
} => {
adjust(condition);
adjust_block(new_pos, accept);
adjust_block(new_pos, reject);
}
Statement::Switch {
ref mut selector,
ref mut cases,
} => {
adjust(selector);
for case in cases.iter_mut() {
adjust_block(new_pos, &mut case.body);
}
}
Statement::Loop {
ref mut body,
ref mut continuing,
ref mut break_if,
} => {
adjust_block(new_pos, body);
adjust_block(new_pos, continuing);
if let Some(e) = break_if.as_mut() {
adjust(e);
}
}
Statement::Return { ref mut value } => {
if let Some(e) = value.as_mut() {
adjust(e);
}
}
Statement::Store {
ref mut pointer,
ref mut value,
} => {
adjust(pointer);
adjust(value);
}
Statement::ImageStore {
ref mut image,
ref mut coordinate,
ref mut array_index,
ref mut value,
} => {
adjust(image);
adjust(coordinate);
if let Some(e) = array_index.as_mut() {
adjust(e);
}
adjust(value);
}
crate::Statement::Atomic {
ref mut pointer,
ref mut value,
ref mut result,
..
} => {
adjust(pointer);
adjust(value);
adjust(result);
}
Statement::WorkGroupUniformLoad {
ref mut pointer,
ref mut result,
} => {
adjust(pointer);
adjust(result);
}
Statement::Call {
ref mut arguments,
ref mut result,
..
} => {
for argument in arguments.iter_mut() {
adjust(argument);
}
if let Some(e) = result.as_mut() {
adjust(e);
}
}
Statement::RayQuery { ref mut query, .. } => {
adjust(query);
}
Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
}
}
/// Filters out expressions that `needs_pre_emit`. This step is necessary after
/// const evaluation since unevaluated expressions could have been included in
/// `Statement::Emit`; but since they have been evaluated we need to filter those
/// out.
fn filter_emits_in_block(block: &Block, expressions: &Arena<Expression>) -> Block {
let mut out = Block::with_capacity(block.len());
for (stmt, span) in block.span_iter() {
match stmt {
&Statement::Emit(ref range) => {
let mut current = None;
for expr_h in range.clone() {
if expressions[expr_h].needs_pre_emit() {
if let Some((first, last)) = current {
out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span);
}
current = None;
} else if let Some((_, ref mut last)) = current {
*last = expr_h;
} else {
current = Some((expr_h, expr_h));
}
}
if let Some((first, last)) = current {
out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span);
}
}
&Statement::Block(ref block) => {
let block = filter_emits_in_block(block, expressions);
out.push(Statement::Block(block), *span);
}
&Statement::If {
condition,
ref accept,
ref reject,
} => {
let accept = filter_emits_in_block(accept, expressions);
let reject = filter_emits_in_block(reject, expressions);
out.push(
Statement::If {
condition,
accept,
reject,
},
*span,
);
}
&Statement::Switch {
selector,
ref cases,
} => {
let cases = cases
.iter()
.map(|case| {
let body = filter_emits_in_block(&case.body, expressions);
SwitchCase {
value: case.value,
body,
fall_through: case.fall_through,
}
})
.collect();
out.push(Statement::Switch { selector, cases }, *span);
}
&Statement::Loop {
ref body,
ref continuing,
break_if,
} => {
let body = filter_emits_in_block(body, expressions);
let continuing = filter_emits_in_block(continuing, expressions);
out.push(
Statement::Loop {
body,
continuing,
break_if,
},
*span,
);
}
stmt => out.push(stmt.clone(), *span),
}
}
out
}
fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
// note that in rust 0.0 == -0.0
match scalar {

View File

@ -14,4 +14,8 @@
override inferred_f32 = 2.718;
@compute @workgroup_size(1)
fn main() {}
fn main() {
var t = height * 5;
let a = !has_point_light;
var x = a;
}

View File

@ -15,7 +15,83 @@
may_kill: false,
sampling_set: [],
global_uses: [],
expressions: [],
expressions: [
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(2),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(4),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 2,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
non_uniform_result: Some(7),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 1,
space: Function,
)),
),
],
sampling: [],
dual_source_blending: false,
),

View File

@ -9,5 +9,10 @@ static const float inferred_f32_ = 2.718;
[numthreads(1, 1, 1)]
void main()
{
float t = (float)0;
bool x = (bool)0;
t = 23.0;
x = true;
return;
}

View File

@ -90,10 +90,54 @@
name: Some("main"),
arguments: [],
result: None,
local_variables: [],
expressions: [],
named_expressions: {},
local_variables: [
(
name: Some("t"),
ty: 2,
init: None,
),
(
name: Some("x"),
ty: 1,
init: None,
),
],
expressions: [
Override(6),
Literal(F32(5.0)),
Binary(
op: Multiply,
left: 1,
right: 2,
),
LocalVariable(1),
Override(1),
Unary(
op: LogicalNot,
expr: 5,
),
LocalVariable(2),
],
named_expressions: {
6: "a",
},
body: [
Emit((
start: 2,
end: 3,
)),
Store(
pointer: 4,
value: 3,
),
Emit((
start: 5,
end: 6,
)),
Store(
pointer: 7,
value: 6,
),
Return(
value: None,
),

View File

@ -90,10 +90,54 @@
name: Some("main"),
arguments: [],
result: None,
local_variables: [],
expressions: [],
named_expressions: {},
local_variables: [
(
name: Some("t"),
ty: 2,
init: None,
),
(
name: Some("x"),
ty: 1,
init: None,
),
],
expressions: [
Override(6),
Literal(F32(5.0)),
Binary(
op: Multiply,
left: 1,
right: 2,
),
LocalVariable(1),
Override(1),
Unary(
op: LogicalNot,
expr: 5,
),
LocalVariable(2),
],
named_expressions: {
6: "a",
},
body: [
Emit((
start: 2,
end: 3,
)),
Store(
pointer: 4,
value: 3,
),
Emit((
start: 5,
end: 6,
)),
Store(
pointer: 7,
value: 6,
),
Return(
value: None,
),

View File

@ -14,5 +14,9 @@ constant float inferred_f32_ = 2.718;
kernel void main_(
) {
float t = {};
bool x = {};
t = 23.0;
x = true;
return;
}

View File

@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 17
; Bound: 24
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@ -19,9 +19,18 @@ OpExecutionMode %14 LocalSize 1 1 1
%11 = OpConstant %4 4.6
%12 = OpConstant %4 2.718
%15 = OpTypeFunction %2
%16 = OpConstant %4 23.0
%18 = OpTypePointer Function %4
%19 = OpConstantNull %4
%21 = OpTypePointer Function %3
%22 = OpConstantNull %3
%14 = OpFunction %2 None %15
%13 = OpLabel
OpBranch %16
%16 = OpLabel
%17 = OpVariable %18 Function %19
%20 = OpVariable %21 Function %22
OpBranch %23
%23 = OpLabel
OpStore %17 %16
OpStore %20 %5
OpReturn
OpFunctionEnd