diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 9194204b27..d1e975b62b 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -164,6 +164,18 @@ fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) { } fn remove_duplicate_types(module: &mut rspirv::dr::Module) { + fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap) { + if let Some(ref mut id) = inst.result_type { + // If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it. + *id = rules.get(id).copied().unwrap_or(*id); + } + for op in &mut inst.operands { + if let rspirv::dr::Operand::IdRef(ref mut id) = op { + *id = rules.get(id).copied().unwrap_or(*id); + } + } + } + // Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module. use rspirv::binary::Assemble; @@ -175,23 +187,27 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) { let mut key_to_result_id = HashMap::new(); for inst in &mut module.types_global_values { + // This is an important spot: Say that we come upon a duplicated aggregate type (one that references + // other types). Its arguments may be duplicated themselves, and so building the key directly will fail + // to match up with the first type. However, **because forward references are not allowed**, we're + // guaranteed to have already found and deduplicated the argument types! So that means the deduplication + // translation is already in rewrite_rules, and we merely need to apply the mapping before generating + // the key. + // Nit: Overwriting the instruction isn't technically necessary, as it will get handled by the final + // all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess. + rewrite_inst_with_rules(inst, &rewrite_rules); + let key = { let mut data = vec![]; data.push(inst.class.opcode as u32); - // TODO: Should this also include the result_type? + if let Some(id) = inst.result_type { + // We're not only deduplicating types here, but constants as well. Those contain result_types, and so we + // need to include those here. For example, OpConstant can have the same arg, but different result_type, + // and it should not be deduplicated (e.g. the constants 1u8 and 1u16). + data.push(id); + } for op in &mut inst.operands { - // This is an important spot: Say that we come upon a duplicated aggregate type (one that references - // other types). Its arguments may be duplicated themselves, and so building the key directly will fail - // to match up with the first type. However, **because forward references are not allowed**, we're - // guaranteed to have already found and deduplicated the argument types! So that means the deduplication - // translation is already in rewrite_rules, and we merely need to apply the mapping before generating - // the key. - if let rspirv::dr::Operand::IdRef(ref mut id) = op { - // Nit: Overwriting the instruction isn't technically necessary, as it will get handled by the final - // all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess. - *id = rewrite_rules.get(id).copied().unwrap_or(*id); - } op.assemble_into(&mut data); } @@ -223,15 +239,7 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) { // Apply the rewrite rules to the whole module for inst in module.all_inst_iter_mut() { - if let Some(ref mut id) = inst.result_type { - // If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it. - *id = rewrite_rules.get(id).copied().unwrap_or(*id); - } - for op in &mut inst.operands { - if let rspirv::dr::Operand::IdRef(ref mut id) = op { - *id = rewrite_rules.get(id).copied().unwrap_or(*id); - } - } + rewrite_inst_with_rules(inst, &rewrite_rules); } }