From 28c0885b4010264e853738b353ca3872fd13130e Mon Sep 17 00:00:00 2001 From: khyperia Date: Thu, 1 Oct 2020 11:08:12 +0200 Subject: [PATCH] Take annotations into account when merging types --- rspirv-linker/src/duplicates.rs | 144 +++++++++++++++++++++++--------- rspirv-linker/src/lib.rs | 2 +- 2 files changed, 104 insertions(+), 42 deletions(-) diff --git a/rspirv-linker/src/duplicates.rs b/rspirv-linker/src/duplicates.rs index 0b1216afd4..6510934630 100644 --- a/rspirv-linker/src/duplicates.rs +++ b/rspirv-linker/src/duplicates.rs @@ -1,4 +1,5 @@ -use crate::operand_idref_mut; +use crate::{operand_idref, operand_idref_mut}; +use rspirv::binary::Assemble; use rspirv::spirv; use std::collections::{hash_map, HashMap, HashSet}; @@ -35,12 +36,14 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) { hash_map::Entry::Occupied(entry) => { let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get()); assert!(old_value.is_none()); + // We're iterating through the vec, so removing items is hard - nop it out. *inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]); } } } } + // Delete the nops we inserted module .ext_inst_imports .retain(|op| op.class.opcode != spirv::Op::Nop); @@ -55,32 +58,108 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) { } } -// TODO: Don't merge zombie types with non-zombie types -pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) { - fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap) { - if let Some(ref mut id) = inst.result_type { - // If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it. - *id = rules.get(id).copied().unwrap_or(*id); - } - for op in &mut inst.operands { - if let Some(id) = operand_idref_mut(op) { - *id = rules.get(id).copied().unwrap_or(*id); +fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec { + let mut data = vec![]; + + data.push(inst.class.opcode as u32); + // skip over the target ID + for op in inst.operands.iter().skip(1) { + op.assemble_into(&mut data); + } + + data +} + +fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap> { + let mut map = HashMap::new(); + for inst in annotations { + if inst.class.opcode == spirv::Op::Decorate + || inst.class.opcode == spirv::Op::MemberDecorate + { + match map.entry(operand_idref(&inst.operands[0]).unwrap()) { + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![make_annotation_key(inst)]); + } + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().push(make_annotation_key(inst)); + } } } } + map.into_iter() + .map(|(key, mut value)| { + (key, { + value.sort(); + value.concat() + }) + }) + .collect() +} +fn make_type_key( + inst: &rspirv::dr::Instruction, + unresolved_forward_pointers: &HashSet, + annotations: &HashMap>, +) -> Vec { + let mut data = vec![]; + + data.push(inst.class.opcode as u32); + if let Some(id) = inst.result_type { + // We're not only deduplicating types here, but constants as well. Those contain result_types, and so we + // need to include those here. For example, OpConstant can have the same arg, but different result_type, + // and it should not be deduplicated (e.g. the constants 1u8 and 1u16). + data.push(id); + } + for op in &inst.operands { + if let rspirv::dr::Operand::IdRef(id) = op { + if unresolved_forward_pointers.contains(id) { + // TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will + // compare equal. + rspirv::dr::Operand::IdRef(0).assemble_into(&mut data); + } else { + op.assemble_into(&mut data); + } + } else { + op.assemble_into(&mut data); + } + } + if let Some(id) = inst.result_id { + if let Some(annos) = annotations.get(&id) { + data.extend_from_slice(annos) + } + } + + data +} + +fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap) { + if let Some(ref mut id) = inst.result_type { + // If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it. + *id = rules.get(id).copied().unwrap_or(*id); + } + for op in &mut inst.operands { + if let Some(id) = operand_idref_mut(op) { + *id = rules.get(id).copied().unwrap_or(*id); + } + } +} + +// TODO: Don't merge zombie types with non-zombie types +pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) { // Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module. - use rspirv::binary::Assemble; // When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID. let mut rewrite_rules = HashMap::new(); - // Instructions are encoded into "keys": their opcode, followed by arguments. Importantly, result_id is left out. - // This means that any instruction that declares the same type, but with different result_id, will result in the - // same key. + // Instructions are encoded into "keys": their opcode, followed by arguments, then annotations. + // Importantly, result_id is left out. This means that any instruction that declares the same + // type, but with different result_id, will result in the same key. let mut key_to_result_id = HashMap::new(); // TODO: This is implementing forward pointers incorrectly. let mut unresolved_forward_pointers = HashSet::new(); + // Collect a map from type ID to an annotation "key blob" (to append to the type key) + let annotations = gather_annotations(&module.annotations); + for inst in &mut module.types_global_values { if inst.class.opcode == spirv::Op::TypeForwardPointer { if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] { @@ -103,32 +182,7 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) { // all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess. rewrite_inst_with_rules(inst, &rewrite_rules); - let key = { - let mut data = vec![]; - - data.push(inst.class.opcode as u32); - if let Some(id) = inst.result_type { - // We're not only deduplicating types here, but constants as well. Those contain result_types, and so we - // need to include those here. For example, OpConstant can have the same arg, but different result_type, - // and it should not be deduplicated (e.g. the constants 1u8 and 1u16). - data.push(id); - } - for op in &inst.operands { - if let rspirv::dr::Operand::IdRef(id) = op { - if unresolved_forward_pointers.contains(id) { - // TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will - // compare equal. - rspirv::dr::Operand::IdRef(0).assemble_into(&mut data); - } else { - op.assemble_into(&mut data); - } - } else { - op.assemble_into(&mut data); - } - } - - data - }; + let key = make_type_key(inst, &unresolved_forward_pointers, &annotations); match key_to_result_id.entry(key) { hash_map::Entry::Vacant(entry) => { @@ -157,4 +211,12 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) { for inst in module.all_inst_iter_mut() { rewrite_inst_with_rules(inst, &rewrite_rules); } + + // The same decorations for duplicated types will cause those different types to merge + // together. So, we need to deduplicate the annotations as well. (Note we *do* care about the + // ID of the type being applied to here, unlike `gather_annotations`) + let mut anno_set = HashSet::new(); + module + .annotations + .retain(|inst| anno_set.insert(inst.assemble())); } diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 49e0eeee35..90f571a0a6 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -99,7 +99,7 @@ pub fn link( } let mut output = loader.module(); - let mut header = rspirv::dr::ModuleHeader::new(bound); + let mut header = rspirv::dr::ModuleHeader::new(bound + 1); header.set_version(version.0, version.1); output.header = Some(header);