Refactor remove_duplicate_types to be a lil cleaner

This commit is contained in:
khyperia 2020-09-22 18:19:33 +02:00
parent abdfa35cce
commit 206af9d2a1

View File

@ -164,6 +164,18 @@ fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
}
fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
if let Some(ref mut id) = inst.result_type {
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
*id = rules.get(id).copied().unwrap_or(*id);
}
for op in &mut inst.operands {
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
*id = rules.get(id).copied().unwrap_or(*id);
}
}
}
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.
use rspirv::binary::Assemble;
@ -175,23 +187,27 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
let mut key_to_result_id = HashMap::new();
for inst in &mut module.types_global_values {
let key = {
let mut data = vec![];
data.push(inst.class.opcode as u32);
// TODO: Should this also include the result_type?
for op in &mut inst.operands {
// This is an important spot: Say that we come upon a duplicated aggregate type (one that references
// other types). Its arguments may be duplicated themselves, and so building the key directly will fail
// to match up with the first type. However, **because forward references are not allowed**, we're
// guaranteed to have already found and deduplicated the argument types! So that means the deduplication
// translation is already in rewrite_rules, and we merely need to apply the mapping before generating
// the key.
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
// Nit: Overwriting the instruction isn't technically necessary, as it will get handled by the final
// all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess.
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
rewrite_inst_with_rules(inst, &rewrite_rules);
let key = {
let mut data = vec![];
data.push(inst.class.opcode as u32);
if let Some(id) = inst.result_type {
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
data.push(id);
}
for op in &mut inst.operands {
op.assemble_into(&mut data);
}
@ -223,15 +239,7 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
// Apply the rewrite rules to the whole module
for inst in module.all_inst_iter_mut() {
if let Some(ref mut id) = inst.result_type {
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
for op in &mut inst.operands {
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
}
rewrite_inst_with_rules(inst, &rewrite_rules);
}
}