Simplify remove_duplicate_types

This commit is contained in:
khyperia 2020-09-22 17:26:17 +02:00
parent 31e8d95898
commit a3f859afa5
2 changed files with 214 additions and 372 deletions

View File

@ -163,119 +163,53 @@ fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
kill_with_id(&mut module.debugs, id);
}
fn remove_duplicate_types(module: rspirv::dr::Module) -> rspirv::dr::Module {
fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
use rspirv::binary::Assemble;
// jb-todo: spirv-tools's linker has special case handling for SpvOpTypeForwardPointer,
// not sure if we need that; see https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp for reference
let mut rewrite_rules = HashMap::new();
let mut key_to_result_id = HashMap::new();
let mut instructions = module
.all_inst_iter()
.cloned()
.collect::<Vec<_>>()
.into_boxed_slice(); // force boxed slice so we don't accidentally grow or shrink it later
let mut def_use_analyzer = DefUseAnalyzer::new(&mut instructions);
let mut kill_annotations = vec![];
let mut continue_from_idx = 0;
// need to do this process iteratively because types can reference each other
loop {
let mut dedup = std::collections::HashMap::new();
let mut duplicate = None;
for (iterator_idx, module_inst) in module
.types_global_values
.iter()
.enumerate()
.skip(continue_from_idx)
{
let (inst_idx, inst) = def_use_analyzer.def(module_inst.result_id.unwrap());
if inst.class.opcode == spirv::Op::Nop {
continue;
}
// partially assemble only the opcode and operands to be used as a key
// maybe this should also include the result_type
let data = {
for inst in &mut module.types_global_values {
let key = {
let mut data = vec![];
data.push(inst.class.opcode as u32);
for op in &inst.operands {
for op in &mut inst.operands {
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
op.assemble_into(&mut data);
}
data
};
// dedup contains a tuple of three indices;
// the first two point into our `def_use_analyzer.instructions` map
// the last one points into the `module.types_global_values` iterator so we can resume iteration
dedup
.entry(data)
.and_modify(|(identical_idx, backtrack_idx)| {
duplicate = Some((inst_idx, *identical_idx, *backtrack_idx));
})
.or_insert((inst_idx, iterator_idx)); // store the index that we encountered an instruction
// for the first time so we can backtrack later
if let Some((_, _, backtrack_idx)) = duplicate {
continue_from_idx = backtrack_idx;
break;
match key_to_result_id.entry(key) {
hash_map::Entry::Vacant(entry) => {
entry.insert(inst.result_id.unwrap());
}
hash_map::Entry::Occupied(entry) => {
assert!(!rewrite_rules.contains_key(&inst.result_id.unwrap()));
rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
}
}
if let Some((before_idx, after_idx, _)) = duplicate {
let before_id = def_use_analyzer.instructions[before_idx].result_id.unwrap();
let after_id = def_use_analyzer.instructions[after_idx].result_id.unwrap();
// remove annotations later
kill_annotations.push(before_id);
def_use_analyzer.for_each_use(before_id, |inst| {
if inst.result_type == Some(before_id) {
inst.result_type = Some(after_id);
}
for op in inst.operands.iter_mut() {
match op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => {
if *w == before_id {
*w = after_id
}
}
_ => {}
}
}
});
// this loop / system works on the assumption that all indices remain valid,
// so instead of removing the instruction we just nop it out - `consume_instruction` will then
// skip all OpNops and they won't appear in the newly constructed module
def_use_analyzer.instructions[before_idx] =
rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
} else {
break;
}
}
let mut loader = rspirv::dr::Loader::new();
for inst in def_use_analyzer.instructions.iter() {
loader.consume_instruction(inst.clone());
}
let mut module = loader.module();
for remove in kill_annotations {
kill_annotations_and_debug(&mut module, remove);
}
module
.types_global_values
.retain(|op| op.class.opcode != spirv::Op::Nop);
for inst in module.all_inst_iter_mut() {
if let Some(ref mut id) = inst.result_type {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
for op in &mut inst.operands {
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
}
}
}
#[derive(Clone, Debug)]
@ -469,91 +403,6 @@ impl DefAnalyzer {
}
}
/// Definition/use index over a mutable slice of instructions.
///
/// Every index stored in the maps is a position in `instructions` as it was
/// when [`DefUseAnalyzer::new`] ran. The maps are built once and are not
/// updated when instructions are later modified through `instructions` or
/// through [`DefUseAnalyzer::for_each_use`].
struct DefUseAnalyzer<'a> {
    /// Result id -> index of the instruction carrying that `result_id`.
    /// If several instructions share a result id, the last one wins.
    def_ids: HashMap<u32, usize>,
    /// Id -> indices of instructions that name the id in an `IdRef`,
    /// `IdScope` or `IdMemorySemantics` operand. An instruction index is
    /// recorded once per matching operand, so it may appear repeatedly.
    use_ids: HashMap<u32, Vec<usize>>,
    /// Id -> indices of instructions whose `result_type` is that id.
    use_result_type_ids: HashMap<u32, Vec<usize>>,
    /// The instructions being indexed; mutated in place by `for_each_use`.
    instructions: &'a mut [rspirv::dr::Instruction],
}

impl<'a> DefUseAnalyzer<'a> {
    /// Scans `instructions` once and records, for every id, where it is
    /// defined and which instructions use it (as result type or as operand).
    fn new(instructions: &'a mut [rspirv::dr::Instruction]) -> Self {
        let mut def_ids = HashMap::new();
        let mut use_ids: HashMap<u32, Vec<usize>> = HashMap::new();
        let mut use_result_type_ids: HashMap<u32, Vec<usize>> = HashMap::new();

        for (inst_idx, inst) in instructions.iter().enumerate() {
            if let Some(def_id) = inst.result_id {
                // A later definition of the same id overwrites the earlier one.
                def_ids.insert(def_id, inst_idx);
            }

            if let Some(result_type) = inst.result_type {
                use_result_type_ids
                    .entry(result_type)
                    .or_default()
                    .push(inst_idx);
            }

            for op in &inst.operands {
                match op {
                    rspirv::dr::Operand::IdMemorySemantics(w)
                    | rspirv::dr::Operand::IdScope(w)
                    | rspirv::dr::Operand::IdRef(w) => {
                        use_ids.entry(*w).or_default().push(inst_idx);
                    }
                    // Non-id operands carry no use information.
                    _ => {}
                }
            }
        }

        Self {
            def_ids,
            use_ids,
            use_result_type_ids,
            instructions,
        }
    }

    /// Returns the index of the instruction that defines `id`.
    ///
    /// # Panics
    ///
    /// Panics if no instruction with `result_id == Some(id)` was seen by `new`.
    fn def_idx(&self, id: u32) -> usize {
        self.def_ids[&id]
    }

    /// Returns the index of the instruction that defines `id` together with a
    /// reference to that instruction.
    ///
    /// # Panics
    ///
    /// Panics if no instruction with `result_id == Some(id)` was seen by `new`.
    fn def(&self, id: u32) -> (usize, &rspirv::dr::Instruction) {
        let idx = self.def_idx(id);
        (idx, &self.instructions[idx])
    }

    /// Calls `f` on every instruction recorded as using `id`: first those
    /// whose `result_type` is `id`, then those that reference it in an operand.
    ///
    /// `f` may run more than once for the same instruction, because an
    /// instruction is recorded once per matching operand and separately for
    /// its result type. Does nothing if `id` has no recorded uses.
    fn for_each_use<F>(&mut self, id: u32, mut f: F)
    where
        F: FnMut(&mut rspirv::dr::Instruction),
    {
        // Uses through `result_type`.
        if let Some(indices) = self.use_result_type_ids.get(&id) {
            for &inst_idx in indices {
                f(&mut self.instructions[inst_idx]);
            }
        }

        // Uses through id operands.
        if let Some(indices) = self.use_ids.get(&id) {
            for &inst_idx in indices {
                f(&mut self.instructions[inst_idx]);
            }
        }
    }
}
fn import_kill_annotations_and_debug(module: &mut rspirv::dr::Module, info: &LinkInfo) {
for import in &info.imports {
kill_annotations_and_debug(module, import.id);
@ -1210,7 +1059,7 @@ pub fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result<rs
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
remove_duplicate_capablities(&mut output);
remove_duplicate_ext_inst_imports(&mut output);
let mut output = remove_duplicate_types(output);
remove_duplicate_types(&mut output);
// jb-todo: strip identical OpDecoration / OpDecorationGroups
// remove names and decorations of import variables / functions https://github.com/KhronosGroup/SPIRV-Tools/blob/8a0ebd40f86d1f18ad42ea96c6ac53915076c3c7/source/opt/ir_context.cpp#L404

View File

@ -1,4 +1,7 @@
use crate::link;
use crate::LinkerError;
use crate::Options;
use crate::Result;
// https://github.com/colin-kiegel/rust-pretty-assertions/issues/24
#[derive(PartialEq, Eq)]
@ -21,7 +24,7 @@ fn assemble_spirv(spirv: &str) -> Vec<u8> {
std::fs::write(&input, spirv).unwrap();
let process = Command::new("spirv-as.exe")
let process = Command::new("spirv-as")
.arg(input.to_str().unwrap())
.arg("-o")
.arg(output.to_str().unwrap())
@ -64,8 +67,7 @@ fn validate(spirv: &[u32]) {
fn load(bytes: &[u8]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new();
rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
let module = loader.module();
module
loader.module()
}
fn assemble_and_link(
@ -73,7 +75,7 @@ fn assemble_and_link(
opts: &crate::Options,
) -> crate::Result<rspirv::dr::Module> {
let mut modules = binaries.iter().cloned().map(load).collect::<Vec<_>>();
let mut modules = modules.iter_mut().map(|m| m).collect::<Vec<_>>();
let mut modules = modules.iter_mut().collect::<Vec<_>>();
link(&mut modules, opts)
}
@ -88,13 +90,13 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
let result = result.disassemble();
let expected = expected
.split("\n")
.split('\n')
.map(|l| l.trim())
.collect::<Vec<_>>()
.join("\n");
let result = result
.split("\n")
.split('\n')
.map(|l| l.trim().replace(" ", " ")) // rspirv outputs multiple spaces between operands
.collect::<Vec<_>>()
.join("\n");
@ -111,14 +113,6 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
}
}
mod test {
use crate::test::assemble_and_link;
use crate::test::assemble_spirv;
use crate::test::without_header_eq;
use crate::LinkerError;
use crate::Options;
use crate::Result;
#[test]
fn standard() -> Result<()> {
let a = assemble_spirv(
@ -236,8 +230,8 @@ mod test {
result.err(),
Some(LinkerError::TypeMismatch {
name: "foo".to_string(),
import_type: "OpTypeFloat 32".to_string(),
export_type: "OpTypeInt 32 0".to_string(),
import_type: "f32".to_string(),
export_type: "u32".to_string(),
})
);
Ok(())
@ -513,4 +507,3 @@ mod test {
without_header_eq(result, expect);
Ok(())
}
}