diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 298ccbc0d..bd9eec76e 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -1,10 +1,11 @@ use super::PipelineConstants; use crate::{ - proc::{ConstantEvaluator, ConstantEvaluatorError}, + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, - Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan, + Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, + Span, Statement, SwitchCase, TypeInner, WithSpan, }; -use std::{borrow::Cow, collections::HashSet}; +use std::{borrow::Cow, collections::HashSet, mem}; use thiserror::Error; #[derive(Error, Debug, Clone)] @@ -175,6 +176,18 @@ pub(super) fn process_overrides<'a>( } } + let mut functions = mem::take(&mut module.functions); + for (_, function) in functions.iter_mut() { + process_function(&mut module, &override_map, function)?; + } + let _ = mem::replace(&mut module.functions, functions); + + let mut entry_points = mem::take(&mut module.entry_points); + for ep in entry_points.iter_mut() { + process_function(&mut module, &override_map, &mut ep.function)?; + } + let _ = mem::replace(&mut module.entry_points, entry_points); + // Now that the global expression arena has changed, we need to // recompute those expressions' types. For the time being, do a // full re-validation. @@ -237,6 +250,64 @@ fn process_override( Ok(h) } +/// Replaces all `Expression::Override`s in this function's expression arena +/// with `Expression::Constant` and evaluates all expressions in its arena. +fn process_function( + module: &mut Module, + override_map: &[Handle], + function: &mut Function, +) -> Result<(), ConstantEvaluatorError> { + // A map from original local expression handles to + // handles in the new, local expression arena. + let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len()); + + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + let mut expressions = mem::take(&mut function.expressions); + + // Dummy `emitter` and `block` for the constant evaluator. + // We can ignore the concept of emitting expressions here since + // expressions have already been covered by a `Statement::Emit` + // in the frontend. + // The only thing we might have to do is remove some expressions + // that have been covered by a `Statement::Emit`. See the docs of + // `filter_emits_in_block` for the reasoning. + let mut emitter = Emitter::default(); + let mut block = Block::new(); + + for (old_h, expr, span) in expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => Expression::Constant(override_map[h.index()]), + expr => expr, + }; + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + adjust_expr(&adjusted_local_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); + adjusted_local_expressions.push(h); + } + + adjust_block(&adjusted_local_expressions, &mut function.body); + + let new_body = filter_emits_in_block(&function.body, &function.expressions); + let _ = mem::replace(&mut function.body, new_body); + + let named_expressions = mem::take(&mut function.named_expressions); + for (expr_h, name) in named_expressions { + function + .named_expressions + .insert(adjusted_local_expressions[expr_h.index()], name); + } + + Ok(()) +} + /// Replace every expression handle in `expr` with its counterpart /// given by `new_pos`. fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { @@ -409,6 +480,207 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { } } +/// Replace every expression handle in `block` with its counterpart +/// given by `new_pos`. +fn adjust_block(new_pos: &[Handle], block: &mut Block) { + for stmt in block.iter_mut() { + adjust_stmt(new_pos, stmt); + } +} + +/// Replace every expression handle in `stmt` with its counterpart +/// given by `new_pos`. +fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *stmt { + Statement::Emit(ref mut range) => { + if let Some((mut first, mut last)) = range.first_and_last() { + adjust(&mut first); + adjust(&mut last); + *range = Range::new_from_bounds(first, last); + } + } + Statement::Block(ref mut block) => { + adjust_block(new_pos, block); + } + Statement::If { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust_block(new_pos, accept); + adjust_block(new_pos, reject); + } + Statement::Switch { + ref mut selector, + ref mut cases, + } => { + adjust(selector); + for case in cases.iter_mut() { + adjust_block(new_pos, &mut case.body); + } + } + Statement::Loop { + ref mut body, + ref mut continuing, + ref mut break_if, + } => { + adjust_block(new_pos, body); + adjust_block(new_pos, continuing); + if let Some(e) = break_if.as_mut() { + adjust(e); + } + } + Statement::Return { ref mut value } => { + if let Some(e) = value.as_mut() { + adjust(e); + } + } + Statement::Store { + ref mut pointer, + ref mut value, + } => { + adjust(pointer); + adjust(value); + } + Statement::ImageStore { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut value, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + adjust(value); + } + crate::Statement::Atomic { + ref mut pointer, + ref mut value, + ref mut result, + .. + } => { + adjust(pointer); + adjust(value); + adjust(result); + } + Statement::WorkGroupUniformLoad { + ref mut pointer, + ref mut result, + } => { + adjust(pointer); + adjust(result); + } + Statement::Call { + ref mut arguments, + ref mut result, + .. + } => { + for argument in arguments.iter_mut() { + adjust(argument); + } + if let Some(e) = result.as_mut() { + adjust(e); + } + } + Statement::RayQuery { ref mut query, .. } => { + adjust(query); + } + Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} + } +} + +/// Filters out expressions that `needs_pre_emit`. This step is necessary after +/// const evaluation since unevaluated expressions could have been included in +/// `Statement::Emit`; but since they have been evaluated we need to filter those +/// out. +fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { + let mut out = Block::with_capacity(block.len()); + for (stmt, span) in block.span_iter() { + match stmt { + &Statement::Emit(ref range) => { + let mut current = None; + for expr_h in range.clone() { + if expressions[expr_h].needs_pre_emit() { + if let Some((first, last)) = current { + out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + } + + current = None; + } else if let Some((_, ref mut last)) = current { + *last = expr_h; + } else { + current = Some((expr_h, expr_h)); + } + } + if let Some((first, last)) = current { + out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + } + } + &Statement::Block(ref block) => { + let block = filter_emits_in_block(block, expressions); + out.push(Statement::Block(block), *span); + } + &Statement::If { + condition, + ref accept, + ref reject, + } => { + let accept = filter_emits_in_block(accept, expressions); + let reject = filter_emits_in_block(reject, expressions); + out.push( + Statement::If { + condition, + accept, + reject, + }, + *span, + ); + } + &Statement::Switch { + selector, + ref cases, + } => { + let cases = cases + .iter() + .map(|case| { + let body = filter_emits_in_block(&case.body, expressions); + SwitchCase { + value: case.value, + body, + fall_through: case.fall_through, + } + }) + .collect(); + out.push(Statement::Switch { selector, cases }, *span); + } + &Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let body = filter_emits_in_block(body, expressions); + let continuing = filter_emits_in_block(continuing, expressions); + out.push( + Statement::Loop { + body, + continuing, + break_if, + }, + *span, + ); + } + stmt => out.push(stmt.clone(), *span), + } + } + out +} + fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { // note that in rust 0.0 == -0.0 match scalar { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index 41e99f942..b06edecdb 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -14,4 +14,8 @@ override inferred_f32 = 2.718; @compute @workgroup_size(1) -fn main() {} \ No newline at end of file +fn main() { + var t = height * 5; + let a = !has_point_light; + var x = a; +} \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 7a2447f3c..389e7fba7 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -15,7 +15,83 @@ may_kill: false, sampling_set: [], global_uses: [], - expressions: [], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(4), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(7), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ], sampling: [], dual_source_blending: false, ), diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 0a849fd4d..1541ae728 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -9,5 +9,10 @@ static const float inferred_f32_ = 2.718; [numthreads(1, 1, 1)] void main() { + float t = (float)0; + bool x = (bool)0; + + t = 23.0; + x = true; return; } diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index 7a60f1423..b0a230a71 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -90,10 +90,54 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [], - named_expressions: {}, + local_variables: [ + ( + name: Some("t"), + ty: 2, + init: None, + ), + ( + name: Some("x"), + ty: 1, + init: None, + ), + ], + expressions: [ + Override(6), + Literal(F32(5.0)), + Binary( + op: Multiply, + left: 1, + right: 2, + ), + LocalVariable(1), + Override(1), + Unary( + op: LogicalNot, + expr: 5, + ), + LocalVariable(2), + ], + named_expressions: { + 6: "a", + }, body: [ + Emit(( + start: 2, + end: 3, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 5, + end: 6, + )), + Store( + pointer: 7, + value: 6, + ), Return( value: None, ), diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 7a60f1423..b0a230a71 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -90,10 +90,54 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [], - named_expressions: {}, + local_variables: [ + ( + name: Some("t"), + ty: 2, + init: None, + ), + ( + name: Some("x"), + ty: 1, + init: None, + ), + ], + expressions: [ + Override(6), + Literal(F32(5.0)), + Binary( + op: Multiply, + left: 1, + right: 2, + ), + LocalVariable(1), + Override(1), + Unary( + op: LogicalNot, + expr: 5, + ), + LocalVariable(2), + ], + named_expressions: { + 6: "a", + }, body: [ + Emit(( + start: 2, + end: 3, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 5, + end: 6, + )), + Store( + pointer: 7, + value: 6, + ), Return( value: None, ), diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 13a3b623a..0bc9e6b12 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -14,5 +14,9 @@ constant float inferred_f32_ = 2.718; kernel void main_( ) { + float t = {}; + bool x = {}; + t = 23.0; + x = true; return; } diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index 7731edfb9..d421606ca 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 17 +; Bound: 24 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -19,9 +19,18 @@ OpExecutionMode %14 LocalSize 1 1 1 %11 = OpConstant %4 4.6 %12 = OpConstant %4 2.718 %15 = OpTypeFunction %2 +%16 = OpConstant %4 23.0 +%18 = OpTypePointer Function %4 +%19 = OpConstantNull %4 +%21 = OpTypePointer Function %3 +%22 = OpConstantNull %3 %14 = OpFunction %2 None %15 %13 = OpLabel -OpBranch %16 -%16 = OpLabel +%17 = OpVariable %18 Function %19 +%20 = OpVariable %21 Function %22 +OpBranch %23 +%23 = OpLabel +OpStore %17 %16 +OpStore %20 %5 OpReturn OpFunctionEnd \ No newline at end of file