Take annotations into account when merging types

This commit is contained in:
khyperia 2020-10-01 11:08:12 +02:00
parent 2c1b73ab9a
commit 28c0885b40
2 changed files with 104 additions and 42 deletions

View File

@ -1,4 +1,5 @@
use crate::operand_idref_mut;
use crate::{operand_idref, operand_idref_mut};
use rspirv::binary::Assemble;
use rspirv::spirv;
use std::collections::{hash_map, HashMap, HashSet};
@ -35,12 +36,14 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
hash_map::Entry::Occupied(entry) => {
let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
assert!(old_value.is_none());
// We're iterating through the vec, so removing items is hard - nop it out.
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
}
}
}
}
// Delete the nops we inserted
module
.ext_inst_imports
.retain(|op| op.class.opcode != spirv::Op::Nop);
@ -55,8 +58,80 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
}
}
// TODO: Don't merge zombie types with non-zombie types
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec<u32> {
let mut data = vec![];
data.push(inst.class.opcode as u32);
// skip over the target ID
for op in inst.operands.iter().skip(1) {
op.assemble_into(&mut data);
}
data
}
fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap<spirv::Word, Vec<u32>> {
let mut map = HashMap::new();
for inst in annotations {
if inst.class.opcode == spirv::Op::Decorate
|| inst.class.opcode == spirv::Op::MemberDecorate
{
match map.entry(operand_idref(&inst.operands[0]).unwrap()) {
hash_map::Entry::Vacant(entry) => {
entry.insert(vec![make_annotation_key(inst)]);
}
hash_map::Entry::Occupied(mut entry) => {
entry.get_mut().push(make_annotation_key(inst));
}
}
}
}
map.into_iter()
.map(|(key, mut value)| {
(key, {
value.sort();
value.concat()
})
})
.collect()
}
fn make_type_key(
inst: &rspirv::dr::Instruction,
unresolved_forward_pointers: &HashSet<spirv::Word>,
annotations: &HashMap<spirv::Word, Vec<u32>>,
) -> Vec<u32> {
let mut data = vec![];
data.push(inst.class.opcode as u32);
if let Some(id) = inst.result_type {
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
data.push(id);
}
for op in &inst.operands {
if let rspirv::dr::Operand::IdRef(id) = op {
if unresolved_forward_pointers.contains(id) {
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
// compare equal.
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
} else {
op.assemble_into(&mut data);
}
} else {
op.assemble_into(&mut data);
}
}
if let Some(id) = inst.result_id {
if let Some(annos) = annotations.get(&id) {
data.extend_from_slice(annos)
}
}
data
}
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
if let Some(ref mut id) = inst.result_type {
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
@ -69,18 +144,22 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
}
}
// TODO: Don't merge zombie types with non-zombie types
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.
use rspirv::binary::Assemble;
// When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID.
let mut rewrite_rules = HashMap::new();
// Instructions are encoded into "keys": their opcode, followed by arguments. Importantly, result_id is left out.
// This means that any instruction that declares the same type, but with different result_id, will result in the
// same key.
// Instructions are encoded into "keys": their opcode, followed by arguments, then annotations.
// Importantly, result_id is left out. This means that any instruction that declares the same
// type, but with different result_id, will result in the same key.
let mut key_to_result_id = HashMap::new();
// TODO: This is implementing forward pointers incorrectly.
let mut unresolved_forward_pointers = HashSet::new();
// Collect a map from type ID to an annotation "key blob" (to append to the type key)
let annotations = gather_annotations(&module.annotations);
for inst in &mut module.types_global_values {
if inst.class.opcode == spirv::Op::TypeForwardPointer {
if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
@ -103,32 +182,7 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
// all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess.
rewrite_inst_with_rules(inst, &rewrite_rules);
let key = {
let mut data = vec![];
data.push(inst.class.opcode as u32);
if let Some(id) = inst.result_type {
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
data.push(id);
}
for op in &inst.operands {
if let rspirv::dr::Operand::IdRef(id) = op {
if unresolved_forward_pointers.contains(id) {
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
// compare equal.
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
} else {
op.assemble_into(&mut data);
}
} else {
op.assemble_into(&mut data);
}
}
data
};
let key = make_type_key(inst, &unresolved_forward_pointers, &annotations);
match key_to_result_id.entry(key) {
hash_map::Entry::Vacant(entry) => {
@ -157,4 +211,12 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
for inst in module.all_inst_iter_mut() {
rewrite_inst_with_rules(inst, &rewrite_rules);
}
// The same decorations for duplicated types will cause those different types to merge
// together. So, we need to deduplicate the annotations as well. (Note we *do* care about the
// ID of the type being applied to here, unlike `gather_annotations`)
let mut anno_set = HashSet::new();
module
.annotations
.retain(|inst| anno_set.insert(inst.assemble()));
}

View File

@ -99,7 +99,7 @@ pub fn link<T>(
}
let mut output = loader.module();
let mut header = rspirv::dr::ModuleHeader::new(bound);
let mut header = rspirv::dr::ModuleHeader::new(bound + 1);
header.set_version(version.0, version.1);
output.header = Some(header);