mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-21 22:33:49 +00:00
evaluate override-expressions in functions
This commit is contained in:
parent
fd5c4db606
commit
3abdfde0ba
@ -1,10 +1,11 @@
|
||||
use super::PipelineConstants;
|
||||
use crate::{
|
||||
proc::{ConstantEvaluator, ConstantEvaluatorError},
|
||||
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
|
||||
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
|
||||
Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan,
|
||||
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
|
||||
Span, Statement, SwitchCase, TypeInner, WithSpan,
|
||||
};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
use std::{borrow::Cow, collections::HashSet, mem};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
@ -175,6 +176,18 @@ pub(super) fn process_overrides<'a>(
|
||||
}
|
||||
}
|
||||
|
||||
let mut functions = mem::take(&mut module.functions);
|
||||
for (_, function) in functions.iter_mut() {
|
||||
process_function(&mut module, &override_map, function)?;
|
||||
}
|
||||
let _ = mem::replace(&mut module.functions, functions);
|
||||
|
||||
let mut entry_points = mem::take(&mut module.entry_points);
|
||||
for ep in entry_points.iter_mut() {
|
||||
process_function(&mut module, &override_map, &mut ep.function)?;
|
||||
}
|
||||
let _ = mem::replace(&mut module.entry_points, entry_points);
|
||||
|
||||
// Now that the global expression arena has changed, we need to
|
||||
// recompute those expressions' types. For the time being, do a
|
||||
// full re-validation.
|
||||
@ -237,6 +250,64 @@ fn process_override(
|
||||
Ok(h)
|
||||
}
|
||||
|
||||
/// Replaces all `Expression::Override`s in this function's expression arena
|
||||
/// with `Expression::Constant` and evaluates all expressions in its arena.
|
||||
fn process_function(
|
||||
module: &mut Module,
|
||||
override_map: &[Handle<Constant>],
|
||||
function: &mut Function,
|
||||
) -> Result<(), ConstantEvaluatorError> {
|
||||
// A map from original local expression handles to
|
||||
// handles in the new, local expression arena.
|
||||
let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len());
|
||||
|
||||
let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
|
||||
|
||||
let mut expressions = mem::take(&mut function.expressions);
|
||||
|
||||
// Dummy `emitter` and `block` for the constant evaluator.
|
||||
// We can ignore the concept of emitting expressions here since
|
||||
// expressions have already been covered by a `Statement::Emit`
|
||||
// in the frontend.
|
||||
// The only thing we might have to do is remove some expressions
|
||||
// that have been covered by a `Statement::Emit`. See the docs of
|
||||
// `filter_emits_in_block` for the reasoning.
|
||||
let mut emitter = Emitter::default();
|
||||
let mut block = Block::new();
|
||||
|
||||
for (old_h, expr, span) in expressions.drain() {
|
||||
let mut expr = match expr {
|
||||
Expression::Override(h) => Expression::Constant(override_map[h.index()]),
|
||||
expr => expr,
|
||||
};
|
||||
let mut evaluator = ConstantEvaluator::for_wgsl_function(
|
||||
module,
|
||||
&mut function.expressions,
|
||||
&mut local_expression_kind_tracker,
|
||||
&mut emitter,
|
||||
&mut block,
|
||||
);
|
||||
adjust_expr(&adjusted_local_expressions, &mut expr);
|
||||
let h = evaluator.try_eval_and_append(expr, span)?;
|
||||
debug_assert_eq!(old_h.index(), adjusted_local_expressions.len());
|
||||
adjusted_local_expressions.push(h);
|
||||
}
|
||||
|
||||
adjust_block(&adjusted_local_expressions, &mut function.body);
|
||||
|
||||
let new_body = filter_emits_in_block(&function.body, &function.expressions);
|
||||
let _ = mem::replace(&mut function.body, new_body);
|
||||
|
||||
let named_expressions = mem::take(&mut function.named_expressions);
|
||||
for (expr_h, name) in named_expressions {
|
||||
function
|
||||
.named_expressions
|
||||
.insert(adjusted_local_expressions[expr_h.index()], name);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Replace every expression handle in `expr` with its counterpart
|
||||
/// given by `new_pos`.
|
||||
fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
|
||||
@ -409,6 +480,207 @@ fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace every expression handle in `block` with its counterpart
|
||||
/// given by `new_pos`.
|
||||
fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) {
|
||||
for stmt in block.iter_mut() {
|
||||
adjust_stmt(new_pos, stmt);
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace every expression handle in `stmt` with its counterpart
|
||||
/// given by `new_pos`.
|
||||
fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
|
||||
let adjust = |expr: &mut Handle<Expression>| {
|
||||
*expr = new_pos[expr.index()];
|
||||
};
|
||||
match *stmt {
|
||||
Statement::Emit(ref mut range) => {
|
||||
if let Some((mut first, mut last)) = range.first_and_last() {
|
||||
adjust(&mut first);
|
||||
adjust(&mut last);
|
||||
*range = Range::new_from_bounds(first, last);
|
||||
}
|
||||
}
|
||||
Statement::Block(ref mut block) => {
|
||||
adjust_block(new_pos, block);
|
||||
}
|
||||
Statement::If {
|
||||
ref mut condition,
|
||||
ref mut accept,
|
||||
ref mut reject,
|
||||
} => {
|
||||
adjust(condition);
|
||||
adjust_block(new_pos, accept);
|
||||
adjust_block(new_pos, reject);
|
||||
}
|
||||
Statement::Switch {
|
||||
ref mut selector,
|
||||
ref mut cases,
|
||||
} => {
|
||||
adjust(selector);
|
||||
for case in cases.iter_mut() {
|
||||
adjust_block(new_pos, &mut case.body);
|
||||
}
|
||||
}
|
||||
Statement::Loop {
|
||||
ref mut body,
|
||||
ref mut continuing,
|
||||
ref mut break_if,
|
||||
} => {
|
||||
adjust_block(new_pos, body);
|
||||
adjust_block(new_pos, continuing);
|
||||
if let Some(e) = break_if.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
Statement::Return { ref mut value } => {
|
||||
if let Some(e) = value.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
Statement::Store {
|
||||
ref mut pointer,
|
||||
ref mut value,
|
||||
} => {
|
||||
adjust(pointer);
|
||||
adjust(value);
|
||||
}
|
||||
Statement::ImageStore {
|
||||
ref mut image,
|
||||
ref mut coordinate,
|
||||
ref mut array_index,
|
||||
ref mut value,
|
||||
} => {
|
||||
adjust(image);
|
||||
adjust(coordinate);
|
||||
if let Some(e) = array_index.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
adjust(value);
|
||||
}
|
||||
crate::Statement::Atomic {
|
||||
ref mut pointer,
|
||||
ref mut value,
|
||||
ref mut result,
|
||||
..
|
||||
} => {
|
||||
adjust(pointer);
|
||||
adjust(value);
|
||||
adjust(result);
|
||||
}
|
||||
Statement::WorkGroupUniformLoad {
|
||||
ref mut pointer,
|
||||
ref mut result,
|
||||
} => {
|
||||
adjust(pointer);
|
||||
adjust(result);
|
||||
}
|
||||
Statement::Call {
|
||||
ref mut arguments,
|
||||
ref mut result,
|
||||
..
|
||||
} => {
|
||||
for argument in arguments.iter_mut() {
|
||||
adjust(argument);
|
||||
}
|
||||
if let Some(e) = result.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
Statement::RayQuery { ref mut query, .. } => {
|
||||
adjust(query);
|
||||
}
|
||||
Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters out expressions that `needs_pre_emit`. This step is necessary after
|
||||
/// const evaluation since unevaluated expressions could have been included in
|
||||
/// `Statement::Emit`; but since they have been evaluated we need to filter those
|
||||
/// out.
|
||||
fn filter_emits_in_block(block: &Block, expressions: &Arena<Expression>) -> Block {
|
||||
let mut out = Block::with_capacity(block.len());
|
||||
for (stmt, span) in block.span_iter() {
|
||||
match stmt {
|
||||
&Statement::Emit(ref range) => {
|
||||
let mut current = None;
|
||||
for expr_h in range.clone() {
|
||||
if expressions[expr_h].needs_pre_emit() {
|
||||
if let Some((first, last)) = current {
|
||||
out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span);
|
||||
}
|
||||
|
||||
current = None;
|
||||
} else if let Some((_, ref mut last)) = current {
|
||||
*last = expr_h;
|
||||
} else {
|
||||
current = Some((expr_h, expr_h));
|
||||
}
|
||||
}
|
||||
if let Some((first, last)) = current {
|
||||
out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span);
|
||||
}
|
||||
}
|
||||
&Statement::Block(ref block) => {
|
||||
let block = filter_emits_in_block(block, expressions);
|
||||
out.push(Statement::Block(block), *span);
|
||||
}
|
||||
&Statement::If {
|
||||
condition,
|
||||
ref accept,
|
||||
ref reject,
|
||||
} => {
|
||||
let accept = filter_emits_in_block(accept, expressions);
|
||||
let reject = filter_emits_in_block(reject, expressions);
|
||||
out.push(
|
||||
Statement::If {
|
||||
condition,
|
||||
accept,
|
||||
reject,
|
||||
},
|
||||
*span,
|
||||
);
|
||||
}
|
||||
&Statement::Switch {
|
||||
selector,
|
||||
ref cases,
|
||||
} => {
|
||||
let cases = cases
|
||||
.iter()
|
||||
.map(|case| {
|
||||
let body = filter_emits_in_block(&case.body, expressions);
|
||||
SwitchCase {
|
||||
value: case.value,
|
||||
body,
|
||||
fall_through: case.fall_through,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
out.push(Statement::Switch { selector, cases }, *span);
|
||||
}
|
||||
&Statement::Loop {
|
||||
ref body,
|
||||
ref continuing,
|
||||
break_if,
|
||||
} => {
|
||||
let body = filter_emits_in_block(body, expressions);
|
||||
let continuing = filter_emits_in_block(continuing, expressions);
|
||||
out.push(
|
||||
Statement::Loop {
|
||||
body,
|
||||
continuing,
|
||||
break_if,
|
||||
},
|
||||
*span,
|
||||
);
|
||||
}
|
||||
stmt => out.push(stmt.clone(), *span),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
|
||||
// note that in rust 0.0 == -0.0
|
||||
match scalar {
|
||||
|
@ -14,4 +14,8 @@
|
||||
override inferred_f32 = 2.718;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {}
|
||||
fn main() {
|
||||
var t = height * 5;
|
||||
let a = !has_point_light;
|
||||
var x = a;
|
||||
}
|
@ -15,7 +15,83 @@
|
||||
may_kill: false,
|
||||
sampling_set: [],
|
||||
global_uses: [],
|
||||
expressions: [],
|
||||
expressions: [
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: None,
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(2),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: None,
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Value(Scalar((
|
||||
kind: Float,
|
||||
width: 4,
|
||||
))),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: None,
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Value(Scalar((
|
||||
kind: Float,
|
||||
width: 4,
|
||||
))),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: Some(4),
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Value(Pointer(
|
||||
base: 2,
|
||||
space: Function,
|
||||
)),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: None,
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(1),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: None,
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(1),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
non_uniform_result: Some(7),
|
||||
requirements: (""),
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Value(Pointer(
|
||||
base: 1,
|
||||
space: Function,
|
||||
)),
|
||||
),
|
||||
],
|
||||
sampling: [],
|
||||
dual_source_blending: false,
|
||||
),
|
||||
|
@ -9,5 +9,10 @@ static const float inferred_f32_ = 2.718;
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
float t = (float)0;
|
||||
bool x = (bool)0;
|
||||
|
||||
t = 23.0;
|
||||
x = true;
|
||||
return;
|
||||
}
|
||||
|
@ -90,10 +90,54 @@
|
||||
name: Some("main"),
|
||||
arguments: [],
|
||||
result: None,
|
||||
local_variables: [],
|
||||
expressions: [],
|
||||
named_expressions: {},
|
||||
local_variables: [
|
||||
(
|
||||
name: Some("t"),
|
||||
ty: 2,
|
||||
init: None,
|
||||
),
|
||||
(
|
||||
name: Some("x"),
|
||||
ty: 1,
|
||||
init: None,
|
||||
),
|
||||
],
|
||||
expressions: [
|
||||
Override(6),
|
||||
Literal(F32(5.0)),
|
||||
Binary(
|
||||
op: Multiply,
|
||||
left: 1,
|
||||
right: 2,
|
||||
),
|
||||
LocalVariable(1),
|
||||
Override(1),
|
||||
Unary(
|
||||
op: LogicalNot,
|
||||
expr: 5,
|
||||
),
|
||||
LocalVariable(2),
|
||||
],
|
||||
named_expressions: {
|
||||
6: "a",
|
||||
},
|
||||
body: [
|
||||
Emit((
|
||||
start: 2,
|
||||
end: 3,
|
||||
)),
|
||||
Store(
|
||||
pointer: 4,
|
||||
value: 3,
|
||||
),
|
||||
Emit((
|
||||
start: 5,
|
||||
end: 6,
|
||||
)),
|
||||
Store(
|
||||
pointer: 7,
|
||||
value: 6,
|
||||
),
|
||||
Return(
|
||||
value: None,
|
||||
),
|
||||
|
@ -90,10 +90,54 @@
|
||||
name: Some("main"),
|
||||
arguments: [],
|
||||
result: None,
|
||||
local_variables: [],
|
||||
expressions: [],
|
||||
named_expressions: {},
|
||||
local_variables: [
|
||||
(
|
||||
name: Some("t"),
|
||||
ty: 2,
|
||||
init: None,
|
||||
),
|
||||
(
|
||||
name: Some("x"),
|
||||
ty: 1,
|
||||
init: None,
|
||||
),
|
||||
],
|
||||
expressions: [
|
||||
Override(6),
|
||||
Literal(F32(5.0)),
|
||||
Binary(
|
||||
op: Multiply,
|
||||
left: 1,
|
||||
right: 2,
|
||||
),
|
||||
LocalVariable(1),
|
||||
Override(1),
|
||||
Unary(
|
||||
op: LogicalNot,
|
||||
expr: 5,
|
||||
),
|
||||
LocalVariable(2),
|
||||
],
|
||||
named_expressions: {
|
||||
6: "a",
|
||||
},
|
||||
body: [
|
||||
Emit((
|
||||
start: 2,
|
||||
end: 3,
|
||||
)),
|
||||
Store(
|
||||
pointer: 4,
|
||||
value: 3,
|
||||
),
|
||||
Emit((
|
||||
start: 5,
|
||||
end: 6,
|
||||
)),
|
||||
Store(
|
||||
pointer: 7,
|
||||
value: 6,
|
||||
),
|
||||
Return(
|
||||
value: None,
|
||||
),
|
||||
|
@ -14,5 +14,9 @@ constant float inferred_f32_ = 2.718;
|
||||
|
||||
kernel void main_(
|
||||
) {
|
||||
float t = {};
|
||||
bool x = {};
|
||||
t = 23.0;
|
||||
x = true;
|
||||
return;
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
; SPIR-V
|
||||
; Version: 1.0
|
||||
; Generator: rspirv
|
||||
; Bound: 17
|
||||
; Bound: 24
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
@ -19,9 +19,18 @@ OpExecutionMode %14 LocalSize 1 1 1
|
||||
%11 = OpConstant %4 4.6
|
||||
%12 = OpConstant %4 2.718
|
||||
%15 = OpTypeFunction %2
|
||||
%16 = OpConstant %4 23.0
|
||||
%18 = OpTypePointer Function %4
|
||||
%19 = OpConstantNull %4
|
||||
%21 = OpTypePointer Function %3
|
||||
%22 = OpConstantNull %3
|
||||
%14 = OpFunction %2 None %15
|
||||
%13 = OpLabel
|
||||
OpBranch %16
|
||||
%16 = OpLabel
|
||||
%17 = OpVariable %18 Function %19
|
||||
%20 = OpVariable %21 Function %22
|
||||
OpBranch %23
|
||||
%23 = OpLabel
|
||||
OpStore %17 %16
|
||||
OpStore %20 %5
|
||||
OpReturn
|
||||
OpFunctionEnd
|
Loading…
Reference in New Issue
Block a user