mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
Take annotations into account when merging types
This commit is contained in:
parent
2c1b73ab9a
commit
28c0885b40
@ -1,4 +1,5 @@
|
||||
use crate::operand_idref_mut;
|
||||
use crate::{operand_idref, operand_idref_mut};
|
||||
use rspirv::binary::Assemble;
|
||||
use rspirv::spirv;
|
||||
use std::collections::{hash_map, HashMap, HashSet};
|
||||
|
||||
@ -35,12 +36,14 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
|
||||
hash_map::Entry::Occupied(entry) => {
|
||||
let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
|
||||
assert!(old_value.is_none());
|
||||
// We're iterating through the vec, so removing items is hard - nop it out.
|
||||
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the nops we inserted
|
||||
module
|
||||
.ext_inst_imports
|
||||
.retain(|op| op.class.opcode != spirv::Op::Nop);
|
||||
@ -55,32 +58,108 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Don't merge zombie types with non-zombie types
|
||||
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
||||
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
|
||||
if let Some(ref mut id) = inst.result_type {
|
||||
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
|
||||
*id = rules.get(id).copied().unwrap_or(*id);
|
||||
}
|
||||
for op in &mut inst.operands {
|
||||
if let Some(id) = operand_idref_mut(op) {
|
||||
*id = rules.get(id).copied().unwrap_or(*id);
|
||||
fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec<u32> {
|
||||
let mut data = vec![];
|
||||
|
||||
data.push(inst.class.opcode as u32);
|
||||
// skip over the target ID
|
||||
for op in inst.operands.iter().skip(1) {
|
||||
op.assemble_into(&mut data);
|
||||
}
|
||||
|
||||
data
|
||||
}
|
||||
|
||||
fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap<spirv::Word, Vec<u32>> {
|
||||
let mut map = HashMap::new();
|
||||
for inst in annotations {
|
||||
if inst.class.opcode == spirv::Op::Decorate
|
||||
|| inst.class.opcode == spirv::Op::MemberDecorate
|
||||
{
|
||||
match map.entry(operand_idref(&inst.operands[0]).unwrap()) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(vec![make_annotation_key(inst)]);
|
||||
}
|
||||
hash_map::Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().push(make_annotation_key(inst));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
map.into_iter()
|
||||
.map(|(key, mut value)| {
|
||||
(key, {
|
||||
value.sort();
|
||||
value.concat()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn make_type_key(
|
||||
inst: &rspirv::dr::Instruction,
|
||||
unresolved_forward_pointers: &HashSet<spirv::Word>,
|
||||
annotations: &HashMap<spirv::Word, Vec<u32>>,
|
||||
) -> Vec<u32> {
|
||||
let mut data = vec![];
|
||||
|
||||
data.push(inst.class.opcode as u32);
|
||||
if let Some(id) = inst.result_type {
|
||||
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
|
||||
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
|
||||
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
|
||||
data.push(id);
|
||||
}
|
||||
for op in &inst.operands {
|
||||
if let rspirv::dr::Operand::IdRef(id) = op {
|
||||
if unresolved_forward_pointers.contains(id) {
|
||||
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
|
||||
// compare equal.
|
||||
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
|
||||
} else {
|
||||
op.assemble_into(&mut data);
|
||||
}
|
||||
} else {
|
||||
op.assemble_into(&mut data);
|
||||
}
|
||||
}
|
||||
if let Some(id) = inst.result_id {
|
||||
if let Some(annos) = annotations.get(&id) {
|
||||
data.extend_from_slice(annos)
|
||||
}
|
||||
}
|
||||
|
||||
data
|
||||
}
|
||||
|
||||
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
|
||||
if let Some(ref mut id) = inst.result_type {
|
||||
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
|
||||
*id = rules.get(id).copied().unwrap_or(*id);
|
||||
}
|
||||
for op in &mut inst.operands {
|
||||
if let Some(id) = operand_idref_mut(op) {
|
||||
*id = rules.get(id).copied().unwrap_or(*id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Don't merge zombie types with non-zombie types
|
||||
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
||||
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.
|
||||
use rspirv::binary::Assemble;
|
||||
|
||||
// When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID.
|
||||
let mut rewrite_rules = HashMap::new();
|
||||
// Instructions are encoded into "keys": their opcode, followed by arguments. Importantly, result_id is left out.
|
||||
// This means that any instruction that declares the same type, but with different result_id, will result in the
|
||||
// same key.
|
||||
// Instructions are encoded into "keys": their opcode, followed by arguments, then annotations.
|
||||
// Importantly, result_id is left out. This means that any instruction that declares the same
|
||||
// type, but with different result_id, will result in the same key.
|
||||
let mut key_to_result_id = HashMap::new();
|
||||
// TODO: This is implementing forward pointers incorrectly.
|
||||
let mut unresolved_forward_pointers = HashSet::new();
|
||||
|
||||
// Collect a map from type ID to an annotation "key blob" (to append to the type key)
|
||||
let annotations = gather_annotations(&module.annotations);
|
||||
|
||||
for inst in &mut module.types_global_values {
|
||||
if inst.class.opcode == spirv::Op::TypeForwardPointer {
|
||||
if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
|
||||
@ -103,32 +182,7 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
||||
// all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess.
|
||||
rewrite_inst_with_rules(inst, &rewrite_rules);
|
||||
|
||||
let key = {
|
||||
let mut data = vec![];
|
||||
|
||||
data.push(inst.class.opcode as u32);
|
||||
if let Some(id) = inst.result_type {
|
||||
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
|
||||
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
|
||||
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
|
||||
data.push(id);
|
||||
}
|
||||
for op in &inst.operands {
|
||||
if let rspirv::dr::Operand::IdRef(id) = op {
|
||||
if unresolved_forward_pointers.contains(id) {
|
||||
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
|
||||
// compare equal.
|
||||
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
|
||||
} else {
|
||||
op.assemble_into(&mut data);
|
||||
}
|
||||
} else {
|
||||
op.assemble_into(&mut data);
|
||||
}
|
||||
}
|
||||
|
||||
data
|
||||
};
|
||||
let key = make_type_key(inst, &unresolved_forward_pointers, &annotations);
|
||||
|
||||
match key_to_result_id.entry(key) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
@ -157,4 +211,12 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
||||
for inst in module.all_inst_iter_mut() {
|
||||
rewrite_inst_with_rules(inst, &rewrite_rules);
|
||||
}
|
||||
|
||||
// The same decorations for duplicated types will cause those different types to merge
|
||||
// together. So, we need to deduplicate the annotations as well. (Note we *do* care about the
|
||||
// ID of the type being applied to here, unlike `gather_annotations`)
|
||||
let mut anno_set = HashSet::new();
|
||||
module
|
||||
.annotations
|
||||
.retain(|inst| anno_set.insert(inst.assemble()));
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ pub fn link<T>(
|
||||
}
|
||||
|
||||
let mut output = loader.module();
|
||||
let mut header = rspirv::dr::ModuleHeader::new(bound);
|
||||
let mut header = rspirv::dr::ModuleHeader::new(bound + 1);
|
||||
header.set_version(version.0, version.1);
|
||||
output.header = Some(header);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user