mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
spirt-passes: add reduce
pass for replacing ops with their inputs/constants.
This commit is contained in:
parent
6ed51e87b2
commit
27c698b302
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2014,6 +2014,7 @@ dependencies = [
|
||||
"either",
|
||||
"hashbrown",
|
||||
"indexmap",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
|
@ -59,6 +59,7 @@ smallvec = { version = "1.6.1", features = ["union"] }
|
||||
spirv-tools = { version = "0.9", default-features = false }
|
||||
rustc_codegen_spirv-types.workspace = true
|
||||
spirt = "0.1.0"
|
||||
lazy_static = "1.4.0"
|
||||
|
||||
[dev-dependencies]
|
||||
pipe = "0.4"
|
||||
|
@ -1,8 +1,92 @@
|
||||
//! SPIR-T pass infrastructure and supporting utilities.
|
||||
|
||||
use rustc_data_structures::fx::FxIndexSet;
|
||||
mod reduce;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
|
||||
use spirt::func_at::FuncAt;
|
||||
use spirt::transform::InnerInPlaceTransform;
|
||||
use spirt::visit::{InnerVisit, Visitor};
|
||||
use spirt::{AttrSet, Const, Context, DeclDef, Func, GlobalVar, Module, Type};
|
||||
use spirt::{
|
||||
spv, AttrSet, Const, Context, ControlNode, ControlNodeKind, ControlRegion, DataInstDef,
|
||||
DataInstKind, DeclDef, EntityOrientedDenseMap, Func, FuncDefBody, GlobalVar, Module, Type,
|
||||
Value,
|
||||
};
|
||||
use std::collections::VecDeque;
|
||||
use std::hash::Hash;
|
||||
use std::iter;
|
||||
|
||||
// HACK(eddyb) `spv::spec::Spec` with extra `WellKnown`s (that should be upstreamed).
|
||||
macro_rules! def_spv_spec_with_extra_well_known {
|
||||
($($group:ident: $ty:ty = [$($entry:ident),+ $(,)?]),+ $(,)?) => {
|
||||
struct SpvSpecWithExtras {
|
||||
__base_spec: &'static spv::spec::Spec,
|
||||
|
||||
well_known: SpvWellKnownWithExtras,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub struct SpvWellKnownWithExtras {
|
||||
__base_well_known: &'static spv::spec::WellKnown,
|
||||
|
||||
$($(pub $entry: $ty,)+)+
|
||||
}
|
||||
|
||||
impl std::ops::Deref for SpvSpecWithExtras {
|
||||
type Target = spv::spec::Spec;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.__base_spec
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for SpvWellKnownWithExtras {
|
||||
type Target = spv::spec::WellKnown;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.__base_well_known
|
||||
}
|
||||
}
|
||||
|
||||
impl SpvSpecWithExtras {
|
||||
#[inline(always)]
|
||||
#[must_use]
|
||||
pub fn get() -> &'static SpvSpecWithExtras {
|
||||
lazy_static! {
|
||||
static ref SPEC: SpvSpecWithExtras = {
|
||||
#[allow(non_camel_case_types)]
|
||||
struct PerWellKnownGroup<$($group),+> {
|
||||
$($group: $group),+
|
||||
}
|
||||
|
||||
let spv_spec = spv::spec::Spec::get();
|
||||
let lookup_fns = PerWellKnownGroup {
|
||||
opcode: |name| spv_spec.instructions.lookup(name).unwrap(),
|
||||
};
|
||||
|
||||
SpvSpecWithExtras {
|
||||
__base_spec: spv_spec,
|
||||
|
||||
well_known: SpvWellKnownWithExtras {
|
||||
__base_well_known: &spv_spec.well_known,
|
||||
|
||||
$($($entry: (lookup_fns.$group)(stringify!($entry)),)+)+
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
&SPEC
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
def_spv_spec_with_extra_well_known! {
|
||||
opcode: spv::spec::Opcode = [
|
||||
OpConstantComposite,
|
||||
|
||||
OpBitcast,
|
||||
OpCompositeInsert,
|
||||
OpCompositeExtract,
|
||||
],
|
||||
}
|
||||
|
||||
/// Run intra-function passes on all `Func` definitions in the `Module`.
|
||||
//
|
||||
@ -36,6 +120,7 @@ pub(super) fn run_func_passes<P>(
|
||||
for name in passes {
|
||||
let name = name.as_ref();
|
||||
let (full_name, pass_fn): (_, fn(_, &mut _)) = match name {
|
||||
"reduce" => ("spirt_passes::reduce", reduce::reduce_in_func),
|
||||
_ => panic!("unknown `--spirt-passes={}`", name),
|
||||
};
|
||||
|
||||
@ -43,6 +128,9 @@ pub(super) fn run_func_passes<P>(
|
||||
for &func in &all_funcs {
|
||||
if let DeclDef::Present(func_def_body) = &mut module.funcs[func].def {
|
||||
pass_fn(cx, func_def_body);
|
||||
|
||||
// FIXME(eddyb) avoid doing this except where changes occurred.
|
||||
remove_unused_values_in_func(func_def_body);
|
||||
}
|
||||
}
|
||||
after_pass(full_name, module, profiler);
|
||||
@ -86,3 +174,409 @@ impl Visitor<'_> for ReachableUseCollector<'_> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME(eddyb) maybe this should be provided by `spirt::visit`.
|
||||
struct VisitAllControlRegionsAndNodes<S, VCR, VCN> {
|
||||
state: S,
|
||||
visit_control_region: VCR,
|
||||
visit_control_node: VCN,
|
||||
}
|
||||
const _: () = {
|
||||
use spirt::{func_at::*, visit::*, *};
|
||||
|
||||
impl<
|
||||
'a,
|
||||
S,
|
||||
VCR: FnMut(&mut S, FuncAt<'a, ControlRegion>),
|
||||
VCN: FnMut(&mut S, FuncAt<'a, ControlNode>),
|
||||
> Visitor<'a> for VisitAllControlRegionsAndNodes<S, VCR, VCN>
|
||||
{
|
||||
// FIXME(eddyb) this is excessive, maybe different kinds of
|
||||
// visitors should exist for module-level and func-level?
|
||||
fn visit_attr_set_use(&mut self, _: AttrSet) {}
|
||||
fn visit_type_use(&mut self, _: Type) {}
|
||||
fn visit_const_use(&mut self, _: Const) {}
|
||||
fn visit_global_var_use(&mut self, _: GlobalVar) {}
|
||||
fn visit_func_use(&mut self, _: Func) {}
|
||||
|
||||
fn visit_control_region_def(&mut self, func_at_control_region: FuncAt<'a, ControlRegion>) {
|
||||
(self.visit_control_region)(&mut self.state, func_at_control_region);
|
||||
func_at_control_region.inner_visit_with(self);
|
||||
}
|
||||
fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'a, ControlNode>) {
|
||||
(self.visit_control_node)(&mut self.state, func_at_control_node);
|
||||
// HACK(eddyb) accidentally private `inner_visit_with` method.
|
||||
fn control_node_inner_visit_with<'a>(
|
||||
self_: FuncAt<'a, ControlNode>,
|
||||
visitor: &mut impl Visitor<'a>,
|
||||
) {
|
||||
let ControlNodeDef { kind, outputs } = self_.def();
|
||||
|
||||
match kind {
|
||||
ControlNodeKind::Block { insts } => {
|
||||
for func_at_inst in self_.at(*insts) {
|
||||
visitor.visit_data_inst_def(func_at_inst.def());
|
||||
}
|
||||
}
|
||||
ControlNodeKind::Select {
|
||||
kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_),
|
||||
scrutinee,
|
||||
cases,
|
||||
} => {
|
||||
visitor.visit_value_use(scrutinee);
|
||||
for &case in cases {
|
||||
visitor.visit_control_region_def(self_.at(case));
|
||||
}
|
||||
}
|
||||
ControlNodeKind::Loop {
|
||||
initial_inputs,
|
||||
body,
|
||||
repeat_condition,
|
||||
} => {
|
||||
for v in initial_inputs {
|
||||
visitor.visit_value_use(v);
|
||||
}
|
||||
visitor.visit_control_region_def(self_.at(*body));
|
||||
visitor.visit_value_use(repeat_condition);
|
||||
}
|
||||
}
|
||||
for output in outputs {
|
||||
output.inner_visit_with(visitor);
|
||||
}
|
||||
}
|
||||
control_node_inner_visit_with(func_at_control_node, self);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// HACK(eddyb) this works around the accidental lack of `spirt::Value: Hash`.
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
struct HashableValue(Value);
|
||||
#[allow(clippy::derive_hash_xor_eq)]
|
||||
impl Hash for HashableValue {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
use spirt::*;
|
||||
#[derive(Hash)]
|
||||
enum ValueH {
|
||||
Const(Const),
|
||||
ControlRegionInput {
|
||||
region: ControlRegion,
|
||||
input_idx: u32,
|
||||
},
|
||||
ControlNodeOutput {
|
||||
control_node: ControlNode,
|
||||
output_idx: u32,
|
||||
},
|
||||
DataInstOutput(DataInst),
|
||||
}
|
||||
match self.0 {
|
||||
Value::Const(ct) => ValueH::Const(ct),
|
||||
Value::ControlRegionInput { region, input_idx } => {
|
||||
ValueH::ControlRegionInput { region, input_idx }
|
||||
}
|
||||
Value::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx,
|
||||
} => ValueH::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx,
|
||||
},
|
||||
Value::DataInstOutput(inst) => ValueH::DataInstOutput(inst),
|
||||
}
|
||||
.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME(eddyb) maybe this should be provided by `spirt::transform`.
|
||||
struct ReplaceValueWith<F>(F);
|
||||
const _: () = {
|
||||
use spirt::{transform::*, *};
|
||||
|
||||
impl<F: FnMut(Value) -> Option<Value>> Transformer for ReplaceValueWith<F> {
|
||||
fn transform_value_use(&mut self, v: &Value) -> Transformed<Value> {
|
||||
self.0(*v).map_or(Transformed::Unchanged, Transformed::Changed)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Clean up after a pass by removing unused (pure) `Value` definitions from
|
||||
/// a function body (both `DataInst`s and `ControlRegion` inputs/outputs).
|
||||
//
|
||||
// FIXME(eddyb) should this be a dedicated pass?
|
||||
fn remove_unused_values_in_func(func_def_body: &mut FuncDefBody) {
|
||||
// Avoid having to support unstructured control-flow.
|
||||
if func_def_body.unstructured_cfg.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let wk = &SpvSpecWithExtras::get().well_known;
|
||||
|
||||
struct Propagator {
|
||||
func_body_region: ControlRegion,
|
||||
|
||||
// FIXME(eddyb) maybe this kind of "parent map" should be provided by SPIR-T?
|
||||
loop_body_to_loop: EntityOrientedDenseMap<ControlRegion, ControlNode>,
|
||||
|
||||
// FIXME(eddyb) entity-keyed dense sets might be better for performance,
|
||||
// but would require separate sets/maps for separate `Value` cases.
|
||||
used: FxHashSet<HashableValue>,
|
||||
|
||||
queue: VecDeque<Value>,
|
||||
}
|
||||
impl Propagator {
|
||||
fn mark_used(&mut self, v: Value) {
|
||||
if let Value::Const(_) = v {
|
||||
return;
|
||||
}
|
||||
if let Value::ControlRegionInput {
|
||||
region,
|
||||
input_idx: _,
|
||||
} = v
|
||||
{
|
||||
if region == self.func_body_region {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if self.used.insert(HashableValue(v)) {
|
||||
self.queue.push_back(v);
|
||||
}
|
||||
}
|
||||
fn propagate_used(&mut self, func: FuncAt<'_, ()>) {
|
||||
while let Some(v) = self.queue.pop_front() {
|
||||
match v {
|
||||
Value::Const(_) => unreachable!(),
|
||||
Value::ControlRegionInput { region, input_idx } => {
|
||||
let loop_node = self.loop_body_to_loop[region];
|
||||
let initial_inputs = match &func.at(loop_node).def().kind {
|
||||
ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
|
||||
// NOTE(eddyb) only `Loop`s' bodies can have inputs right now.
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.mark_used(initial_inputs[input_idx as usize]);
|
||||
self.mark_used(func.at(region).def().outputs[input_idx as usize]);
|
||||
}
|
||||
Value::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx,
|
||||
} => {
|
||||
let cases = match &func.at(control_node).def().kind {
|
||||
ControlNodeKind::Select { cases, .. } => cases,
|
||||
// NOTE(eddyb) only `Select`s can have outputs right now.
|
||||
_ => unreachable!(),
|
||||
};
|
||||
for &case in cases {
|
||||
self.mark_used(func.at(case).def().outputs[output_idx as usize]);
|
||||
}
|
||||
}
|
||||
Value::DataInstOutput(inst) => {
|
||||
for &input in &func.at(inst).def().inputs {
|
||||
self.mark_used(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HACK(eddyb) it's simpler to first ensure `loop_body_to_loop` is computed,
|
||||
// just to allow the later unordered propagation to always work.
|
||||
let propagator = {
|
||||
let mut visitor = VisitAllControlRegionsAndNodes {
|
||||
state: Propagator {
|
||||
func_body_region: func_def_body.body,
|
||||
loop_body_to_loop: Default::default(),
|
||||
used: Default::default(),
|
||||
queue: Default::default(),
|
||||
},
|
||||
visit_control_region: |_: &mut _, _| {},
|
||||
visit_control_node:
|
||||
|propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
|
||||
if let ControlNodeKind::Loop { body, .. } = func_at_control_node.def().kind {
|
||||
propagator
|
||||
.loop_body_to_loop
|
||||
.insert(body, func_at_control_node.position);
|
||||
}
|
||||
},
|
||||
};
|
||||
func_def_body.inner_visit_with(&mut visitor);
|
||||
visitor.state
|
||||
};
|
||||
|
||||
// HACK(eddyb) this kind of random-access is easier than using `spirt::transform`.
|
||||
let mut all_control_nodes = vec![];
|
||||
|
||||
let used_values = {
|
||||
let mut visitor = VisitAllControlRegionsAndNodes {
|
||||
state: propagator,
|
||||
visit_control_region: |_: &mut _, _| {},
|
||||
visit_control_node:
|
||||
|propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
|
||||
all_control_nodes.push(func_at_control_node.position);
|
||||
|
||||
let mut mark_used_and_propagate = |v| {
|
||||
propagator.mark_used(v);
|
||||
propagator.propagate_used(func_at_control_node.at(()));
|
||||
};
|
||||
match func_at_control_node.def().kind {
|
||||
ControlNodeKind::Block { insts } => {
|
||||
for func_at_inst in func_at_control_node.at(insts) {
|
||||
// Ignore pure instructions (i.e. they're only used
|
||||
// if their output value is used, from somewhere else).
|
||||
if let DataInstKind::SpvInst(spv_inst) = &func_at_inst.def().kind {
|
||||
// HACK(eddyb) small selection relevant for now,
|
||||
// but should be extended using e.g. a bitset.
|
||||
if [wk.OpNop, wk.OpCompositeInsert].contains(&spv_inst.opcode) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
mark_used_and_propagate(Value::DataInstOutput(
|
||||
func_at_inst.position,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
ControlNodeKind::Select { scrutinee: v, .. }
|
||||
| ControlNodeKind::Loop {
|
||||
repeat_condition: v,
|
||||
..
|
||||
} => mark_used_and_propagate(v),
|
||||
}
|
||||
},
|
||||
};
|
||||
func_def_body.inner_visit_with(&mut visitor);
|
||||
|
||||
let mut propagator = visitor.state;
|
||||
for &v in &func_def_body.at_body().def().outputs {
|
||||
propagator.mark_used(v);
|
||||
propagator.propagate_used(func_def_body.at(()));
|
||||
}
|
||||
|
||||
assert!(propagator.queue.is_empty());
|
||||
propagator.used
|
||||
};
|
||||
|
||||
// FIXME(eddyb) entity-keyed dense maps might be better for performance,
|
||||
// but would require separate maps for separate `Value` cases.
|
||||
let mut value_replacements = FxHashMap::default();
|
||||
|
||||
// Remove anything that didn't end up marked as used (directly or indirectly).
|
||||
for control_node in all_control_nodes {
|
||||
let control_node_def = func_def_body.at(control_node).def();
|
||||
match &control_node_def.kind {
|
||||
&ControlNodeKind::Block { insts } => {
|
||||
let mut all_nops = true;
|
||||
let mut func_at_inst_iter = func_def_body.at_mut(insts).into_iter();
|
||||
while let Some(mut func_at_inst) = func_at_inst_iter.next() {
|
||||
if let DataInstKind::SpvInst(spv_inst) = &func_at_inst.reborrow().def().kind {
|
||||
if spv_inst.opcode == wk.OpNop {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if !used_values
|
||||
.contains(&HashableValue(Value::DataInstOutput(func_at_inst.position)))
|
||||
{
|
||||
// Replace the removed `DataInstDef` itself with `OpNop`,
|
||||
// removing the ability to use its "name" as a value.
|
||||
*func_at_inst.def() = DataInstDef {
|
||||
attrs: Default::default(),
|
||||
kind: DataInstKind::SpvInst(wk.OpNop.into()),
|
||||
output_type: None,
|
||||
inputs: iter::empty().collect(),
|
||||
};
|
||||
continue;
|
||||
}
|
||||
all_nops = false;
|
||||
}
|
||||
// HACK(eddyb) because we can't remove list elements yet, we
|
||||
// instead replace blocks of `OpNop`s with empty ones.
|
||||
if all_nops {
|
||||
func_def_body.at_mut(control_node).def().kind = ControlNodeKind::Block {
|
||||
insts: Default::default(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ControlNodeKind::Select { cases, .. } => {
|
||||
// FIXME(eddyb) remove this cloning.
|
||||
let cases = cases.clone();
|
||||
|
||||
let mut new_idx = 0;
|
||||
for original_idx in 0..control_node_def.outputs.len() {
|
||||
let original_output = Value::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx: original_idx as u32,
|
||||
};
|
||||
|
||||
if !used_values.contains(&HashableValue(original_output)) {
|
||||
// Remove the output definition and corresponding value from all cases.
|
||||
func_def_body
|
||||
.at_mut(control_node)
|
||||
.def()
|
||||
.outputs
|
||||
.remove(new_idx);
|
||||
for &case in &cases {
|
||||
func_def_body.at_mut(case).def().outputs.remove(new_idx);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Record remappings for any still-used outputs that got "shifted over".
|
||||
if original_idx != new_idx {
|
||||
let new_output = Value::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx: new_idx as u32,
|
||||
};
|
||||
value_replacements.insert(HashableValue(original_output), new_output);
|
||||
}
|
||||
new_idx += 1;
|
||||
}
|
||||
}
|
||||
ControlNodeKind::Loop {
|
||||
body,
|
||||
initial_inputs,
|
||||
..
|
||||
} => {
|
||||
let body = *body;
|
||||
|
||||
let mut new_idx = 0;
|
||||
for original_idx in 0..initial_inputs.len() {
|
||||
let original_input = Value::ControlRegionInput {
|
||||
region: body,
|
||||
input_idx: original_idx as u32,
|
||||
};
|
||||
|
||||
if !used_values.contains(&HashableValue(original_input)) {
|
||||
// Remove the input definition and corresponding values.
|
||||
match &mut func_def_body.at_mut(control_node).def().kind {
|
||||
ControlNodeKind::Loop { initial_inputs, .. } => {
|
||||
initial_inputs.remove(new_idx);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let body_def = func_def_body.at_mut(body).def();
|
||||
body_def.inputs.remove(new_idx);
|
||||
body_def.outputs.remove(new_idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Record remappings for any still-used inputs that got "shifted over".
|
||||
if original_idx != new_idx {
|
||||
let new_input = Value::ControlRegionInput {
|
||||
region: body,
|
||||
input_idx: new_idx as u32,
|
||||
};
|
||||
value_replacements.insert(HashableValue(original_input), new_input);
|
||||
}
|
||||
new_idx += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !value_replacements.is_empty() {
|
||||
func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|v| match v {
|
||||
Value::Const(_) => None,
|
||||
_ => value_replacements.get(&HashableValue(v)).copied(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
879
crates/rustc_codegen_spirv/src/linker/spirt_passes/reduce.rs
Normal file
879
crates/rustc_codegen_spirv/src/linker/spirt_passes/reduce.rs
Normal file
@ -0,0 +1,879 @@
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use smallvec::SmallVec;
|
||||
use spirt::func_at::{FuncAt, FuncAtMut};
|
||||
use spirt::transform::InnerInPlaceTransform;
|
||||
use spirt::visit::InnerVisit;
|
||||
use spirt::{
|
||||
spv, Const, ConstCtor, ConstDef, Context, ControlNode, ControlNodeDef, ControlNodeKind,
|
||||
ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef,
|
||||
DataInstKind, EntityOrientedDenseMap, FuncDefBody, SelectionKind, Type, TypeCtor, TypeDef,
|
||||
Value,
|
||||
};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::hash::Hash;
|
||||
use std::{iter, slice};
|
||||
|
||||
use super::{HashableValue, ReplaceValueWith, VisitAllControlRegionsAndNodes};
|
||||
|
||||
/// Apply "reduction rules" to `func_def_body`, replacing (pure) computations
|
||||
/// with one of their inputs or a constant (e.g. `x + 0 => x` or `1 + 2 => 3`),
|
||||
/// and at most only adding more `Select` outputs/`Loop` state (where necessary)
|
||||
/// but never any new instructions (unlike e.g. LLVM's instcombine).
|
||||
pub(crate) fn reduce_in_func(cx: &Context, func_def_body: &mut FuncDefBody) {
|
||||
let wk = &super::SpvSpecWithExtras::get().well_known;
|
||||
|
||||
let parent_map = ParentMap::new(func_def_body);
|
||||
|
||||
// FIXME(eddyb) entity-keyed dense maps might be better for performance,
|
||||
// but would require separate maps for separate `Value` cases.
|
||||
let mut value_replacements = FxHashMap::default();
|
||||
|
||||
let mut reduction_cache = FxHashMap::default();
|
||||
|
||||
// HACK(eddyb) this is an annoying workaround for iterator invalidation
|
||||
// (SPIR-T iterators don't cope well with the underlying data changing).
|
||||
//
|
||||
// FIXME(eddyb) replace SPIR-T `FuncAtMut<EntityListIter<T>>` with some
|
||||
// kind of "list cursor", maybe even allowing removal during traversal.
|
||||
let mut reduction_queue = vec![];
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum ReductionTarget {
|
||||
/// Replace uses of a `DataInst` with a reduced `Value`.
|
||||
DataInst(DataInst),
|
||||
|
||||
/// Replace an `OpSwitch` `ControlNode` with an `if`-`else` one.
|
||||
//
|
||||
// HACK(eddyb) see comment in `handle_control_node` for more details.
|
||||
SwitchToIfElse(ControlNode),
|
||||
}
|
||||
|
||||
loop {
|
||||
let old_value_replacements_len = value_replacements.len();
|
||||
|
||||
// HACK(eddyb) we want to transform `DataInstDef`s, while having the ability
|
||||
// to (mutably) traverse the function, but `in_place_transform_data_inst_def`
|
||||
// only gives us a `&mut DataInstDef` (without the `FuncAtMut` around it).
|
||||
//
|
||||
// HACK(eddyb) ignore the above, for now it's pretty bad due to iterator
|
||||
// invalidation (see comment on `let reduction_queue` too).
|
||||
let mut handle_control_node =
|
||||
|func_at_control_node: FuncAt<'_, ControlNode>| match func_at_control_node.def() {
|
||||
&ControlNodeDef {
|
||||
kind: ControlNodeKind::Block { insts },
|
||||
..
|
||||
} => {
|
||||
for func_at_inst in func_at_control_node.at(insts) {
|
||||
if let Ok(redu) = Reducible::try_from(func_at_inst.def()) {
|
||||
let redu_target = ReductionTarget::DataInst(func_at_inst.position);
|
||||
reduction_queue.push((redu_target, redu));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ControlNodeDef {
|
||||
kind:
|
||||
ControlNodeKind::Select {
|
||||
kind,
|
||||
scrutinee,
|
||||
cases,
|
||||
},
|
||||
outputs,
|
||||
} => {
|
||||
// FIXME(eddyb) this should probably be ran in the queue loop
|
||||
// below, to more quickly benefit from previous reductions.
|
||||
for i in 0..u32::try_from(outputs.len()).unwrap() {
|
||||
let output = Value::ControlNodeOutput {
|
||||
control_node: func_at_control_node.position,
|
||||
output_idx: i,
|
||||
};
|
||||
if let Entry::Vacant(entry) =
|
||||
value_replacements.entry(HashableValue(output))
|
||||
{
|
||||
let per_case_value = cases.iter().map(|&case| {
|
||||
func_at_control_node.at(case).def().outputs[i as usize]
|
||||
});
|
||||
if let Some(reduced) = try_reduce_select(
|
||||
cx,
|
||||
&parent_map,
|
||||
func_at_control_node.position,
|
||||
kind,
|
||||
*scrutinee,
|
||||
per_case_value,
|
||||
) {
|
||||
entry.insert(reduced);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HACK(eddyb) turn `switch x { case 0: A; case 1: B; default: ... }`
|
||||
// into `if ... {B} else {A}`, when `x` ends up limited in `0..=1`,
|
||||
// (such `switch`es come from e.g. `match`-ing enums w/ 2 variants)
|
||||
// allowing us to bypass SPIR-T current (and temporary) lossiness
|
||||
// wrt `default: OpUnreachable` (i.e. we prove the `default:` can't
|
||||
// be entered based on `x` not having values other than `0` or `1`)
|
||||
if let SelectionKind::SpvInst(spv_inst) = kind {
|
||||
if spv_inst.opcode == wk.OpSwitch && cases.len() == 3 {
|
||||
// FIXME(eddyb) this kind of `OpSwitch` decoding logic should
|
||||
// be done by SPIR-T ahead of time, not here.
|
||||
let num_logical_imms = cases.len() - 1;
|
||||
assert_eq!(spv_inst.imms.len() % num_logical_imms, 0);
|
||||
let logical_imm_size = spv_inst.imms.len() / num_logical_imms;
|
||||
// FIXME(eddyb) collect to array instead.
|
||||
let logical_imms_as_u32s: SmallVec<[_; 2]> = spv_inst
|
||||
.imms
|
||||
.chunks(logical_imm_size)
|
||||
.map(spv_imm_checked_trunc32)
|
||||
.collect();
|
||||
|
||||
// FIMXE(eddyb) support more values than just `0..=1`.
|
||||
if logical_imms_as_u32s[..] == [Some(0), Some(1)] {
|
||||
let redu = Reducible {
|
||||
op: PureOp::IntToBool,
|
||||
output_type: cx.intern(TypeDef {
|
||||
attrs: Default::default(),
|
||||
ctor: TypeCtor::SpvInst(wk.OpTypeBool.into()),
|
||||
ctor_args: iter::empty().collect(),
|
||||
}),
|
||||
input: *scrutinee,
|
||||
};
|
||||
let redu_target =
|
||||
ReductionTarget::SwitchToIfElse(func_at_control_node.position);
|
||||
reduction_queue.push((redu_target, redu));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ControlNodeDef {
|
||||
kind:
|
||||
ControlNodeKind::Loop {
|
||||
body,
|
||||
initial_inputs,
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {
|
||||
// FIXME(eddyb) this should probably be ran in the queue loop
|
||||
// below, to more quickly benefit from previous reductions.
|
||||
let body_outputs = &func_at_control_node.at(*body).def().outputs;
|
||||
for (i, (&initial_input, &body_output)) in
|
||||
initial_inputs.iter().zip(body_outputs).enumerate()
|
||||
{
|
||||
let body_input = Value::ControlRegionInput {
|
||||
region: *body,
|
||||
input_idx: i as u32,
|
||||
};
|
||||
if body_output == body_input {
|
||||
value_replacements
|
||||
.entry(HashableValue(body_input))
|
||||
.or_insert(initial_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
func_def_body.inner_visit_with(&mut VisitAllControlRegionsAndNodes {
|
||||
state: (),
|
||||
visit_control_region: |_: &mut (), _| {},
|
||||
visit_control_node: |_: &mut (), func_at_control_node| {
|
||||
handle_control_node(func_at_control_node);
|
||||
},
|
||||
});
|
||||
|
||||
// FIXME(eddyb) should this loop become the only loop, by having loop
|
||||
// reductions push the new instruction to `reduction_queue`? the problem
|
||||
// then is that it's not trivial to figure out what else might benefit
|
||||
// from another full scan, so perhaps the only solution is "demand-driven"
|
||||
// (recursing into use->def, instead of processing defs).
|
||||
let mut any_changes = false;
|
||||
for (redu_target, redu) in reduction_queue.drain(..) {
|
||||
if let Some(v) = redu.try_reduce(
|
||||
cx,
|
||||
func_def_body.at_mut(()),
|
||||
&value_replacements,
|
||||
&parent_map,
|
||||
&mut reduction_cache,
|
||||
) {
|
||||
any_changes = true;
|
||||
match redu_target {
|
||||
ReductionTarget::DataInst(inst) => {
|
||||
value_replacements.insert(HashableValue(Value::DataInstOutput(inst)), v);
|
||||
|
||||
// Replace the reduced `DataInstDef` itself with `OpNop`,
|
||||
// removing the ability to use its "name" as a value.
|
||||
*func_def_body.at_mut(inst).def() = DataInstDef {
|
||||
attrs: Default::default(),
|
||||
kind: DataInstKind::SpvInst(wk.OpNop.into()),
|
||||
output_type: None,
|
||||
inputs: iter::empty().collect(),
|
||||
};
|
||||
}
|
||||
|
||||
// HACK(eddyb) see comment in `handle_control_node` for more details.
|
||||
ReductionTarget::SwitchToIfElse(control_node) => {
|
||||
let control_node_def = func_def_body.at_mut(control_node).def();
|
||||
match &control_node_def.kind {
|
||||
ControlNodeKind::Select { cases, .. } => match cases[..] {
|
||||
[_default, case_0, case_1] => {
|
||||
control_node_def.kind = ControlNodeKind::Select {
|
||||
kind: SelectionKind::BoolCond,
|
||||
scrutinee: v,
|
||||
cases: [case_1, case_0].iter().copied().collect(),
|
||||
};
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !any_changes && old_value_replacements_len == value_replacements.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|v| match v {
|
||||
Value::Const(_) => None,
|
||||
_ => value_replacements
|
||||
.get(&HashableValue(v))
|
||||
.copied()
|
||||
.map(|new| {
|
||||
any_changes = true;
|
||||
new
|
||||
}),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME(eddyb) maybe this kind of "parent map" should be provided by SPIR-T?
|
||||
#[derive(Default)]
|
||||
struct ParentMap {
|
||||
data_inst_parent: EntityOrientedDenseMap<DataInst, ControlNode>,
|
||||
control_node_parent: EntityOrientedDenseMap<ControlNode, ControlRegion>,
|
||||
control_region_parent: EntityOrientedDenseMap<ControlRegion, ControlNode>,
|
||||
}
|
||||
|
||||
impl ParentMap {
|
||||
fn new(func_def_body: &FuncDefBody) -> Self {
|
||||
let mut visitor = VisitAllControlRegionsAndNodes {
|
||||
state: Self::default(),
|
||||
visit_control_region:
|
||||
|this: &mut Self, func_at_control_region: FuncAt<'_, ControlRegion>| {
|
||||
for func_at_child_control_node in func_at_control_region.at_children() {
|
||||
this.control_node_parent.insert(
|
||||
func_at_child_control_node.position,
|
||||
func_at_control_region.position,
|
||||
);
|
||||
}
|
||||
},
|
||||
visit_control_node: |this: &mut Self, func_at_control_node: FuncAt<'_, ControlNode>| {
|
||||
let child_regions = match &func_at_control_node.def().kind {
|
||||
&ControlNodeKind::Block { insts } => {
|
||||
for func_at_inst in func_at_control_node.at(insts) {
|
||||
this.data_inst_parent
|
||||
.insert(func_at_inst.position, func_at_control_node.position);
|
||||
}
|
||||
&[][..]
|
||||
}
|
||||
|
||||
ControlNodeKind::Select { cases, .. } => cases,
|
||||
ControlNodeKind::Loop { body, .. } => slice::from_ref(body),
|
||||
};
|
||||
for &child_region in child_regions {
|
||||
this.control_region_parent
|
||||
.insert(child_region, func_at_control_node.position);
|
||||
}
|
||||
},
|
||||
};
|
||||
func_def_body.inner_visit_with(&mut visitor);
|
||||
visitor.state
|
||||
}
|
||||
}
|
||||
|
||||
/// If possible, find a single `Value` from `cases` (or even `scrutinee`),
|
||||
/// which would always be a valid result for `Select(kind, scrutinee, cases)`,
|
||||
/// regardless of which case gets (dynamically) taken.
|
||||
fn try_reduce_select(
|
||||
cx: &Context,
|
||||
parent_map: &ParentMap,
|
||||
select_control_node: ControlNode,
|
||||
// FIXME(eddyb) are these redundant with the `ControlNode` above?
|
||||
kind: &SelectionKind,
|
||||
scrutinee: Value,
|
||||
cases: impl Iterator<Item = Value>,
|
||||
) -> Option<Value> {
|
||||
let wk = &super::SpvSpecWithExtras::get().well_known;
|
||||
|
||||
let as_spv_const = |v: Value| match v {
|
||||
Value::Const(ct) => match &cx[ct].ctor {
|
||||
ConstCtor::SpvInst(spv_inst) => Some(spv_inst.opcode),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Ignore `OpUndef`s, as they can be legally substituted with any other value.
|
||||
let mut first_undef = None;
|
||||
let mut non_undef_cases = cases.filter(|&case| {
|
||||
let is_undef = as_spv_const(case) == Some(wk.OpUndef);
|
||||
if is_undef && first_undef.is_none() {
|
||||
first_undef = Some(case);
|
||||
}
|
||||
!is_undef
|
||||
});
|
||||
match (non_undef_cases.next(), non_undef_cases.next()) {
|
||||
(None, _) => first_undef,
|
||||
|
||||
// `Select(c: bool, true, false)` can be replaced with just `c`.
|
||||
(Some(x), Some(y))
|
||||
if matches!(kind, SelectionKind::BoolCond)
|
||||
&& as_spv_const(x) == Some(wk.OpConstantTrue)
|
||||
&& as_spv_const(y) == Some(wk.OpConstantFalse) =>
|
||||
{
|
||||
assert!(non_undef_cases.next().is_none() && first_undef.is_none());
|
||||
|
||||
Some(scrutinee)
|
||||
}
|
||||
|
||||
(Some(x), y) => {
|
||||
if y.into_iter().chain(non_undef_cases).all(|z| z == x) {
|
||||
// HACK(eddyb) closure here serves as `try` block.
|
||||
let is_x_valid_outside_select = || {
|
||||
// Constants are always valid.
|
||||
if let Value::Const(_) = x {
|
||||
return Some(());
|
||||
}
|
||||
|
||||
// HACK(eddyb) if the same value appears in two different
|
||||
// cases, it's definitely dominating the whole `Select`.
|
||||
if y.is_some() {
|
||||
return Some(());
|
||||
}
|
||||
|
||||
// In general, `x` dominating the `Select` is what would
|
||||
// allow lifting an use of it outside the `Select`.
|
||||
let region_defining_x = match x {
|
||||
Value::Const(_) => unreachable!(),
|
||||
Value::ControlRegionInput { region, .. } => region,
|
||||
Value::ControlNodeOutput { control_node, .. } => {
|
||||
*parent_map.control_node_parent.get(control_node)?
|
||||
}
|
||||
Value::DataInstOutput(inst) => *parent_map
|
||||
.control_node_parent
|
||||
.get(*parent_map.data_inst_parent.get(inst)?)?,
|
||||
};
|
||||
|
||||
// Fast-reject: if `x` is defined immediately inside one of
|
||||
// `select_control_node`'s cases, it's not a dominator.
|
||||
if parent_map.control_region_parent.get(region_defining_x)
|
||||
== Some(&select_control_node)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Since we know `x` is used inside the `Select`, this only
|
||||
// needs to check that `x` is defined in a region that the
|
||||
// `Select` is nested in, as the only other possibility is
|
||||
// that the `x` is defined inside the `Select` - that is,
|
||||
// one of `x` and `Select` always dominates the other.
|
||||
//
|
||||
// FIXME(eddyb) this could be more efficient with some kind
|
||||
// of "region depth" precomputation but a potentially-slower
|
||||
// check doubles as a sanity check, for now.
|
||||
let mut region_containing_select =
|
||||
*parent_map.control_node_parent.get(select_control_node)?;
|
||||
loop {
|
||||
if region_containing_select == region_defining_x {
|
||||
return Some(());
|
||||
}
|
||||
region_containing_select = *parent_map.control_node_parent.get(
|
||||
*parent_map
|
||||
.control_region_parent
|
||||
.get(region_containing_select)?,
|
||||
)?;
|
||||
}
|
||||
};
|
||||
if is_x_valid_outside_select().is_some() {
|
||||
return Some(x);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pure operation that transforms one `Value` into another `Value`.
|
||||
//
|
||||
// FIXME(eddyb) move this elsewhere? also, how should binops etc. be supported?
|
||||
// (one approach could be having a "focus input" that can be dynamic, with the
|
||||
// other inputs being `Const`s, i.e. partially applying all but one input)
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
enum PureOp {
|
||||
BitCast,
|
||||
CompositeExtract {
|
||||
elem_idx: spv::Imm,
|
||||
},
|
||||
|
||||
/// Maps `0` to `false`, and `1` to `true`, but any other input values won't
|
||||
/// allow reduction, which is used to signal `0..=1` isn't being guaranteed.
|
||||
//
|
||||
// HACK(eddyb) not a real operation, but a helper used to extract a `bool`
|
||||
// equivalent for an `OpSwitch`'s scrutinee.
|
||||
// FIXME(eddyb) proper SPIR-T range analysis should be implemented and such
|
||||
// a reduction not attempted at all if the range is larger than `0..=1`
|
||||
// (also, the actual operation can be replaced with `x == 1` or `x != 0`)
|
||||
IntToBool,
|
||||
}
|
||||
|
||||
impl TryFrom<&spv::Inst> for PureOp {
|
||||
type Error = ();
|
||||
fn try_from(spv_inst: &spv::Inst) -> Result<Self, ()> {
|
||||
let wk = &super::SpvSpecWithExtras::get().well_known;
|
||||
|
||||
let op = spv_inst.opcode;
|
||||
Ok(match spv_inst.imms[..] {
|
||||
[] if op == wk.OpBitcast => Self::BitCast,
|
||||
|
||||
// FIXME(eddyb) support more than one index at a time, somehow.
|
||||
[elem_idx] if op == wk.OpCompositeExtract => Self::CompositeExtract { elem_idx },
|
||||
|
||||
_ => return Err(()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<PureOp> for spv::Inst {
|
||||
type Error = ();
|
||||
fn try_from(op: PureOp) -> Result<Self, ()> {
|
||||
let wk = &super::SpvSpecWithExtras::get().well_known;
|
||||
|
||||
let (opcode, imms) = match op {
|
||||
PureOp::BitCast => (wk.OpBitcast, iter::empty().collect()),
|
||||
PureOp::CompositeExtract { elem_idx } => {
|
||||
(wk.OpCompositeExtract, iter::once(elem_idx).collect())
|
||||
}
|
||||
|
||||
// HACK(eddyb) this is the only reason this is `TryFrom` not `From`.
|
||||
PureOp::IntToBool => return Err(()),
|
||||
};
|
||||
Ok(Self { opcode, imms })
|
||||
}
|
||||
}
|
||||
|
||||
/// Potentially-reducible application of a `PureOp` (`op`) to `input`.
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
struct Reducible<V = Value> {
|
||||
op: PureOp,
|
||||
output_type: Type,
|
||||
input: V,
|
||||
}
|
||||
|
||||
// HACK(eddyb) this works around the accidental lack of `spirt::Value: Hash`.
|
||||
impl Hash for Reducible<Value> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.op.hash(state);
|
||||
self.output_type.hash(state);
|
||||
HashableValue(self.input).hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> Reducible<V> {
|
||||
fn with_input<V2>(self, new_input: V2) -> Reducible<V2> {
|
||||
Reducible {
|
||||
op: self.op,
|
||||
output_type: self.output_type,
|
||||
input: new_input,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&DataInstDef> for Reducible {
|
||||
type Error = ();
|
||||
fn try_from(inst_def: &DataInstDef) -> Result<Self, ()> {
|
||||
if let DataInstKind::SpvInst(spv_inst) = &inst_def.kind {
|
||||
let op = PureOp::try_from(spv_inst)?;
|
||||
let output_type = inst_def.output_type.unwrap();
|
||||
if let [input] = inst_def.inputs[..] {
|
||||
return Ok(Self {
|
||||
op,
|
||||
output_type,
|
||||
input,
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
// HACK(eddyb) `IntToBool` is the only reason this is `TryFrom` not `From`.
|
||||
impl TryFrom<Reducible> for DataInstDef {
|
||||
type Error = ();
|
||||
fn try_from(redu: Reducible) -> Result<Self, ()> {
|
||||
Ok(Self {
|
||||
attrs: Default::default(),
|
||||
kind: DataInstKind::SpvInst(redu.op.try_into()?),
|
||||
output_type: Some(redu.output_type),
|
||||
inputs: iter::once(redu.input).collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `Some(lowest32)` iff `imms` contains one *logical* SPIR-V immediate
|
||||
/// representing a (little-endian) integer which truncates (if wider than 32 bits)
|
||||
/// to `lowest32`, losslessly (i.e. the rest of the bits are all zeros).
|
||||
//
|
||||
// FIXME(eddyb) move this into some kind of utility/common helpers place.
|
||||
fn spv_imm_checked_trunc32(imms: &[spv::Imm]) -> Option<u32> {
|
||||
match imms {
|
||||
&[spv::Imm::Short(_, lowest32)] | &[spv::Imm::LongStart(_, lowest32), ..]
|
||||
if imms[1..]
|
||||
.iter()
|
||||
.all(|imm| matches!(imm, spv::Imm::LongCont(_, 0))) =>
|
||||
{
|
||||
Some(lowest32)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl Reducible<Const> {
|
||||
// FIXME(eddyb) in theory this should always return `Some`.
|
||||
fn try_reduce_const(&self, cx: &Context) -> Option<Const> {
|
||||
let wk = &super::SpvSpecWithExtras::get().well_known;
|
||||
|
||||
let ct_def = &cx[self.input];
|
||||
match (self.op, &ct_def.ctor) {
|
||||
(_, ConstCtor::SpvInst(spv_inst)) if spv_inst.opcode == wk.OpUndef => {
|
||||
Some(cx.intern(ConstDef {
|
||||
attrs: ct_def.attrs,
|
||||
ty: self.output_type,
|
||||
ctor: ct_def.ctor.clone(),
|
||||
ctor_args: iter::empty().collect(),
|
||||
}))
|
||||
}
|
||||
|
||||
(PureOp::BitCast, ConstCtor::SpvInst(spv_inst)) if spv_inst.opcode == wk.OpConstant => {
|
||||
// `OpTypeInt`/`OpTypeFloat` bit width.
|
||||
let scalar_width = |ty: Type| match &cx[ty].ctor {
|
||||
TypeCtor::SpvInst(spv_inst)
|
||||
if [wk.OpTypeInt, wk.OpTypeFloat].contains(&spv_inst.opcode) =>
|
||||
{
|
||||
Some(spv_inst.imms[0])
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
match (scalar_width(ct_def.ty), scalar_width(self.output_type)) {
|
||||
(Some(from), Some(to)) if from == to => Some(cx.intern(ConstDef {
|
||||
attrs: ct_def.attrs,
|
||||
ty: self.output_type,
|
||||
ctor: ct_def.ctor.clone(),
|
||||
ctor_args: ct_def.ctor_args.clone(),
|
||||
})),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
PureOp::CompositeExtract {
|
||||
elem_idx: spv::Imm::Short(_, elem_idx),
|
||||
},
|
||||
ConstCtor::SpvInst(spv_inst),
|
||||
) if spv_inst.opcode == wk.OpConstantComposite => {
|
||||
Some(ct_def.ctor_args[elem_idx as usize])
|
||||
}
|
||||
|
||||
(PureOp::IntToBool, ConstCtor::SpvInst(spv_inst))
|
||||
if spv_inst.opcode == wk.OpConstant =>
|
||||
{
|
||||
let bool_const_op = match spv_imm_checked_trunc32(&spv_inst.imms[..]) {
|
||||
Some(0) => wk.OpConstantFalse,
|
||||
Some(1) => wk.OpConstantTrue,
|
||||
_ => return None,
|
||||
};
|
||||
Some(cx.intern(ConstDef {
|
||||
attrs: Default::default(),
|
||||
ty: self.output_type,
|
||||
ctor: ConstCtor::SpvInst(bool_const_op.into()),
|
||||
ctor_args: iter::empty().collect(),
|
||||
}))
|
||||
}
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of a single step of a reduction (which may require more steps).
|
||||
enum ReductionStep {
|
||||
Complete(Value),
|
||||
Partial(Reducible),
|
||||
}
|
||||
|
||||
impl Reducible<&DataInstDef> {
|
||||
// FIXME(eddyb) force the input to actually be itself some kind of pure op.
|
||||
fn try_reduce_output_of_data_inst(&self) -> Option<ReductionStep> {
|
||||
let wk = &super::SpvSpecWithExtras::get().well_known;
|
||||
|
||||
let input_inst_def = self.input;
|
||||
if let DataInstKind::SpvInst(input_spv_inst) = &input_inst_def.kind {
|
||||
// NOTE(eddyb) do not destroy information left in e.g. comments.
|
||||
#[allow(clippy::match_same_arms)]
|
||||
match self.op {
|
||||
PureOp::BitCast => {
|
||||
// FIXME(eddyb) reduce chains of bitcasts.
|
||||
}
|
||||
|
||||
PureOp::CompositeExtract { elem_idx } => {
|
||||
if input_spv_inst.opcode == wk.OpCompositeInsert
|
||||
&& input_spv_inst.imms.len() == 1
|
||||
{
|
||||
let new_elem = input_inst_def.inputs[0];
|
||||
let prev_composite = input_inst_def.inputs[1];
|
||||
return Some(if input_spv_inst.imms[0] == elem_idx {
|
||||
ReductionStep::Complete(new_elem)
|
||||
} else {
|
||||
ReductionStep::Partial(self.with_input(prev_composite))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
PureOp::IntToBool => {
|
||||
// FIXME(eddyb) look into what instructions might end up
|
||||
// being used to transform booleans into integers.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Reducible {
|
||||
// FIXME(eddyb) make this into some kind of local `ReduceCx` method.
|
||||
fn try_reduce(
|
||||
mut self,
|
||||
cx: &Context,
|
||||
// FIXME(eddyb) come up with a better convention for this!
|
||||
func: FuncAtMut<'_, ()>,
|
||||
|
||||
value_replacements: &FxHashMap<HashableValue, Value>,
|
||||
|
||||
parent_map: &ParentMap,
|
||||
|
||||
cache: &mut FxHashMap<Self, Option<Value>>,
|
||||
) -> Option<Value> {
|
||||
// FIXME(eddyb) should we care about the cache *before* this loop below?
|
||||
|
||||
// HACK(eddyb) eagerly apply `value_replacements`.
|
||||
// FIXME(eddyb) this could do the union-find trick of shortening chains
|
||||
// the first time they're encountered, but also, if this process was more
|
||||
// "demand-driven" (recursing into use->def, instead of processing defs),
|
||||
// it might not require any of this complication.
|
||||
while let Some(&replacement) = value_replacements.get(&HashableValue(self.input)) {
|
||||
self.input = replacement;
|
||||
}
|
||||
|
||||
if let Some(&cached) = cache.get(&self) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
let result = self.try_reduce_uncached(cx, func, value_replacements, parent_map, cache);
|
||||
|
||||
cache.insert(self, result);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// FIXME(eddyb) make this into some kind of local `ReduceCx` method.
|
||||
fn try_reduce_uncached(
|
||||
self,
|
||||
cx: &Context,
|
||||
// FIXME(eddyb) come up with a better convention for this!
|
||||
mut func: FuncAtMut<'_, ()>,
|
||||
|
||||
value_replacements: &FxHashMap<HashableValue, Value>,
|
||||
|
||||
parent_map: &ParentMap,
|
||||
|
||||
cache: &mut FxHashMap<Self, Option<Value>>,
|
||||
) -> Option<Value> {
|
||||
match self.input {
|
||||
Value::Const(ct) => self.with_input(ct).try_reduce_const(cx).map(Value::Const),
|
||||
Value::ControlRegionInput {
|
||||
region,
|
||||
input_idx: state_idx,
|
||||
} => {
|
||||
let loop_node = *parent_map.control_region_parent.get(region)?;
|
||||
// HACK(eddyb) this can't be a closure due to lifetime elision.
|
||||
fn loop_initial_states(
|
||||
func_at_loop_node: FuncAtMut<'_, ControlNode>,
|
||||
) -> &mut SmallVec<[Value; 2]> {
|
||||
match &mut func_at_loop_node.def().kind {
|
||||
ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
let input_from_initial_state =
|
||||
loop_initial_states(func.reborrow().at(loop_node))[state_idx as usize];
|
||||
let input_from_updated_state =
|
||||
func.reborrow().at(region).def().outputs[state_idx as usize];
|
||||
|
||||
let output_from_initial_state = self
|
||||
.with_input(input_from_initial_state)
|
||||
.try_reduce(cx, func.reborrow(), value_replacements, parent_map, cache)?;
|
||||
// HACK(eddyb) this is here because it can fail, see the comment
|
||||
// on `output_from_updated_state` for what's actually going on.
|
||||
let output_from_updated_state_inst =
|
||||
DataInstDef::try_from(self.with_input(input_from_updated_state)).ok()?;
|
||||
|
||||
// Now that the reduction succeeded for the initial state,
|
||||
// we can proceed with augmenting the loop with the extra state.
|
||||
loop_initial_states(func.reborrow().at(loop_node)).push(output_from_initial_state);
|
||||
|
||||
let loop_state_decls = &mut func.reborrow().at(region).def().inputs;
|
||||
let new_loop_state_idx = u32::try_from(loop_state_decls.len()).unwrap();
|
||||
loop_state_decls.push(ControlRegionInputDecl {
|
||||
attrs: Default::default(),
|
||||
ty: self.output_type,
|
||||
});
|
||||
|
||||
// HACK(eddyb) generating the instruction wholesale again is not
|
||||
// the most efficient way to go about this, but avoiding getting
|
||||
// stuck in a loop while processing a loop is also important.
|
||||
//
|
||||
// FIXME(eddyb) attempt to replace this with early-inserting in
|
||||
// `cache` *then* returning.
|
||||
let output_from_updated_state = func
|
||||
.data_insts
|
||||
.define(cx, output_from_updated_state_inst.into());
|
||||
func.reborrow()
|
||||
.at(region)
|
||||
.def()
|
||||
.outputs
|
||||
.push(Value::DataInstOutput(output_from_updated_state));
|
||||
|
||||
// FIXME(eddyb) move this into some kind of utility/common helpers.
|
||||
let loop_body_last_block = func
|
||||
.reborrow()
|
||||
.at(region)
|
||||
.def()
|
||||
.children
|
||||
.iter()
|
||||
.last
|
||||
.filter(|&node| {
|
||||
matches!(
|
||||
func.reborrow().at(node).def().kind,
|
||||
ControlNodeKind::Block { .. }
|
||||
)
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
let new_block = func.control_nodes.define(
|
||||
cx,
|
||||
ControlNodeDef {
|
||||
kind: ControlNodeKind::Block {
|
||||
insts: Default::default(),
|
||||
},
|
||||
outputs: Default::default(),
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
func.control_regions[region]
|
||||
.children
|
||||
.insert_last(new_block, func.control_nodes);
|
||||
new_block
|
||||
});
|
||||
match &mut func.control_nodes[loop_body_last_block].kind {
|
||||
ControlNodeKind::Block { insts } => {
|
||||
insts.insert_last(output_from_updated_state, func.data_insts);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
Some(Value::ControlRegionInput {
|
||||
region,
|
||||
input_idx: new_loop_state_idx,
|
||||
})
|
||||
}
|
||||
Value::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx,
|
||||
} => {
|
||||
let cases = match &func.reborrow().at(control_node).def().kind {
|
||||
ControlNodeKind::Select { cases, .. } => cases,
|
||||
// NOTE(eddyb) only `Select`s can have outputs right now.
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// FIXME(eddyb) remove all the cloning and undo additions of new
|
||||
// outputs "upstream", if they end up unused (or let DCE do it?).
|
||||
let cases = cases.clone();
|
||||
let per_case_new_output: SmallVec<[_; 2]> = cases
|
||||
.iter()
|
||||
.map(|&case| {
|
||||
let per_case_input =
|
||||
func.reborrow().at(case).def().outputs[output_idx as usize];
|
||||
self.with_input(per_case_input).try_reduce(
|
||||
cx,
|
||||
func.reborrow(),
|
||||
value_replacements,
|
||||
parent_map,
|
||||
cache,
|
||||
)
|
||||
})
|
||||
.collect::<Option<_>>()?;
|
||||
|
||||
// Try to avoid introducing a new output, by reducing the merge
|
||||
// of the per-case output values to a single value, if possible.
|
||||
let (kind, scrutinee) = match &func.reborrow().at(control_node).def().kind {
|
||||
ControlNodeKind::Select {
|
||||
kind, scrutinee, ..
|
||||
} => (kind, *scrutinee),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if let Some(v) = try_reduce_select(
|
||||
cx,
|
||||
parent_map,
|
||||
control_node,
|
||||
kind,
|
||||
scrutinee,
|
||||
per_case_new_output.iter().copied(),
|
||||
) {
|
||||
return Some(v);
|
||||
}
|
||||
|
||||
// Merge the per-case output values into a new output.
|
||||
let control_node_output_decls = &mut func.reborrow().at(control_node).def().outputs;
|
||||
let new_output_idx = u32::try_from(control_node_output_decls.len()).unwrap();
|
||||
control_node_output_decls.push(ControlNodeOutputDecl {
|
||||
attrs: Default::default(),
|
||||
ty: self.output_type,
|
||||
});
|
||||
for (&case, new_output) in cases.iter().zip(per_case_new_output) {
|
||||
let per_case_outputs = &mut func.reborrow().at(case).def().outputs;
|
||||
assert_eq!(per_case_outputs.len(), new_output_idx as usize);
|
||||
per_case_outputs.push(new_output);
|
||||
}
|
||||
Some(Value::ControlNodeOutput {
|
||||
control_node,
|
||||
output_idx: new_output_idx,
|
||||
})
|
||||
}
|
||||
Value::DataInstOutput(inst) => {
|
||||
let inst_def = &*func.reborrow().at(inst).def();
|
||||
match self.with_input(inst_def).try_reduce_output_of_data_inst()? {
|
||||
ReductionStep::Complete(v) => Some(v),
|
||||
// FIXME(eddyb) actually use a loop instead of recursing here.
|
||||
ReductionStep::Partial(redu) => {
|
||||
redu.try_reduce(cx, func, value_replacements, parent_map, cache)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user