mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 16:25:25 +00:00
Take annotations into account when merging types
This commit is contained in:
parent
2c1b73ab9a
commit
28c0885b40
@ -1,4 +1,5 @@
|
|||||||
use crate::operand_idref_mut;
|
use crate::{operand_idref, operand_idref_mut};
|
||||||
|
use rspirv::binary::Assemble;
|
||||||
use rspirv::spirv;
|
use rspirv::spirv;
|
||||||
use std::collections::{hash_map, HashMap, HashSet};
|
use std::collections::{hash_map, HashMap, HashSet};
|
||||||
|
|
||||||
@ -35,12 +36,14 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
|
|||||||
hash_map::Entry::Occupied(entry) => {
|
hash_map::Entry::Occupied(entry) => {
|
||||||
let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
|
let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
|
||||||
assert!(old_value.is_none());
|
assert!(old_value.is_none());
|
||||||
|
// We're iterating through the vec, so removing items is hard - nop it out.
|
||||||
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
|
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete the nops we inserted
|
||||||
module
|
module
|
||||||
.ext_inst_imports
|
.ext_inst_imports
|
||||||
.retain(|op| op.class.opcode != spirv::Op::Nop);
|
.retain(|op| op.class.opcode != spirv::Op::Nop);
|
||||||
@ -55,8 +58,80 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Don't merge zombie types with non-zombie types
|
fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec<u32> {
|
||||||
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
let mut data = vec![];
|
||||||
|
|
||||||
|
data.push(inst.class.opcode as u32);
|
||||||
|
// skip over the target ID
|
||||||
|
for op in inst.operands.iter().skip(1) {
|
||||||
|
op.assemble_into(&mut data);
|
||||||
|
}
|
||||||
|
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap<spirv::Word, Vec<u32>> {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
for inst in annotations {
|
||||||
|
if inst.class.opcode == spirv::Op::Decorate
|
||||||
|
|| inst.class.opcode == spirv::Op::MemberDecorate
|
||||||
|
{
|
||||||
|
match map.entry(operand_idref(&inst.operands[0]).unwrap()) {
|
||||||
|
hash_map::Entry::Vacant(entry) => {
|
||||||
|
entry.insert(vec![make_annotation_key(inst)]);
|
||||||
|
}
|
||||||
|
hash_map::Entry::Occupied(mut entry) => {
|
||||||
|
entry.get_mut().push(make_annotation_key(inst));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map.into_iter()
|
||||||
|
.map(|(key, mut value)| {
|
||||||
|
(key, {
|
||||||
|
value.sort();
|
||||||
|
value.concat()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_type_key(
|
||||||
|
inst: &rspirv::dr::Instruction,
|
||||||
|
unresolved_forward_pointers: &HashSet<spirv::Word>,
|
||||||
|
annotations: &HashMap<spirv::Word, Vec<u32>>,
|
||||||
|
) -> Vec<u32> {
|
||||||
|
let mut data = vec![];
|
||||||
|
|
||||||
|
data.push(inst.class.opcode as u32);
|
||||||
|
if let Some(id) = inst.result_type {
|
||||||
|
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
|
||||||
|
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
|
||||||
|
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
|
||||||
|
data.push(id);
|
||||||
|
}
|
||||||
|
for op in &inst.operands {
|
||||||
|
if let rspirv::dr::Operand::IdRef(id) = op {
|
||||||
|
if unresolved_forward_pointers.contains(id) {
|
||||||
|
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
|
||||||
|
// compare equal.
|
||||||
|
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
|
||||||
|
} else {
|
||||||
|
op.assemble_into(&mut data);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
op.assemble_into(&mut data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(id) = inst.result_id {
|
||||||
|
if let Some(annos) = annotations.get(&id) {
|
||||||
|
data.extend_from_slice(annos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
|
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
|
||||||
if let Some(ref mut id) = inst.result_type {
|
if let Some(ref mut id) = inst.result_type {
|
||||||
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
|
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
|
||||||
@ -69,18 +144,22 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Don't merge zombie types with non-zombie types
|
||||||
|
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
||||||
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.
|
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.
|
||||||
use rspirv::binary::Assemble;
|
|
||||||
|
|
||||||
// When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID.
|
// When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID.
|
||||||
let mut rewrite_rules = HashMap::new();
|
let mut rewrite_rules = HashMap::new();
|
||||||
// Instructions are encoded into "keys": their opcode, followed by arguments. Importantly, result_id is left out.
|
// Instructions are encoded into "keys": their opcode, followed by arguments, then annotations.
|
||||||
// This means that any instruction that declares the same type, but with different result_id, will result in the
|
// Importantly, result_id is left out. This means that any instruction that declares the same
|
||||||
// same key.
|
// type, but with different result_id, will result in the same key.
|
||||||
let mut key_to_result_id = HashMap::new();
|
let mut key_to_result_id = HashMap::new();
|
||||||
// TODO: This is implementing forward pointers incorrectly.
|
// TODO: This is implementing forward pointers incorrectly.
|
||||||
let mut unresolved_forward_pointers = HashSet::new();
|
let mut unresolved_forward_pointers = HashSet::new();
|
||||||
|
|
||||||
|
// Collect a map from type ID to an annotation "key blob" (to append to the type key)
|
||||||
|
let annotations = gather_annotations(&module.annotations);
|
||||||
|
|
||||||
for inst in &mut module.types_global_values {
|
for inst in &mut module.types_global_values {
|
||||||
if inst.class.opcode == spirv::Op::TypeForwardPointer {
|
if inst.class.opcode == spirv::Op::TypeForwardPointer {
|
||||||
if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
|
if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
|
||||||
@ -103,32 +182,7 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
|||||||
// all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess.
|
// all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess.
|
||||||
rewrite_inst_with_rules(inst, &rewrite_rules);
|
rewrite_inst_with_rules(inst, &rewrite_rules);
|
||||||
|
|
||||||
let key = {
|
let key = make_type_key(inst, &unresolved_forward_pointers, &annotations);
|
||||||
let mut data = vec![];
|
|
||||||
|
|
||||||
data.push(inst.class.opcode as u32);
|
|
||||||
if let Some(id) = inst.result_type {
|
|
||||||
// We're not only deduplicating types here, but constants as well. Those contain result_types, and so we
|
|
||||||
// need to include those here. For example, OpConstant can have the same arg, but different result_type,
|
|
||||||
// and it should not be deduplicated (e.g. the constants 1u8 and 1u16).
|
|
||||||
data.push(id);
|
|
||||||
}
|
|
||||||
for op in &inst.operands {
|
|
||||||
if let rspirv::dr::Operand::IdRef(id) = op {
|
|
||||||
if unresolved_forward_pointers.contains(id) {
|
|
||||||
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
|
|
||||||
// compare equal.
|
|
||||||
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
|
|
||||||
} else {
|
|
||||||
op.assemble_into(&mut data);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
op.assemble_into(&mut data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data
|
|
||||||
};
|
|
||||||
|
|
||||||
match key_to_result_id.entry(key) {
|
match key_to_result_id.entry(key) {
|
||||||
hash_map::Entry::Vacant(entry) => {
|
hash_map::Entry::Vacant(entry) => {
|
||||||
@ -157,4 +211,12 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
|||||||
for inst in module.all_inst_iter_mut() {
|
for inst in module.all_inst_iter_mut() {
|
||||||
rewrite_inst_with_rules(inst, &rewrite_rules);
|
rewrite_inst_with_rules(inst, &rewrite_rules);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The same decorations for duplicated types will cause those different types to merge
|
||||||
|
// together. So, we need to deduplicate the annotations as well. (Note we *do* care about the
|
||||||
|
// ID of the type being applied to here, unlike `gather_annotations`)
|
||||||
|
let mut anno_set = HashSet::new();
|
||||||
|
module
|
||||||
|
.annotations
|
||||||
|
.retain(|inst| anno_set.insert(inst.assemble()));
|
||||||
}
|
}
|
||||||
|
@ -99,7 +99,7 @@ pub fn link<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut output = loader.module();
|
let mut output = loader.module();
|
||||||
let mut header = rspirv::dr::ModuleHeader::new(bound);
|
let mut header = rspirv::dr::ModuleHeader::new(bound + 1);
|
||||||
header.set_version(version.0, version.1);
|
header.set_version(version.0, version.1);
|
||||||
output.header = Some(header);
|
output.header = Some(header);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user