mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 14:56:27 +00:00
Simplify remove_duplicate_types
This commit is contained in:
parent
31e8d95898
commit
a3f859afa5
@ -163,119 +163,53 @@ fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
|
||||
kill_with_id(&mut module.debugs, id);
|
||||
}
|
||||
|
||||
fn remove_duplicate_types(module: rspirv::dr::Module) -> rspirv::dr::Module {
|
||||
fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
|
||||
use rspirv::binary::Assemble;
|
||||
|
||||
// jb-todo: spirv-tools's linker has special case handling for SpvOpTypeForwardPointer,
|
||||
// not sure if we need that; see https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp for reference
|
||||
let mut rewrite_rules = HashMap::new();
|
||||
let mut key_to_result_id = HashMap::new();
|
||||
|
||||
let mut instructions = module
|
||||
.all_inst_iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice(); // force boxed slice so we don't accidentally grow or shrink it later
|
||||
|
||||
let mut def_use_analyzer = DefUseAnalyzer::new(&mut instructions);
|
||||
|
||||
let mut kill_annotations = vec![];
|
||||
let mut continue_from_idx = 0;
|
||||
|
||||
// need to do this process iteratively because types can reference each other
|
||||
loop {
|
||||
let mut dedup = std::collections::HashMap::new();
|
||||
let mut duplicate = None;
|
||||
|
||||
for (iterator_idx, module_inst) in module
|
||||
.types_global_values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.skip(continue_from_idx)
|
||||
{
|
||||
let (inst_idx, inst) = def_use_analyzer.def(module_inst.result_id.unwrap());
|
||||
|
||||
if inst.class.opcode == spirv::Op::Nop {
|
||||
continue;
|
||||
}
|
||||
|
||||
// partially assemble only the opcode and operands to be used as a key
|
||||
// maybe this should also include the result_type
|
||||
let data = {
|
||||
for inst in &mut module.types_global_values {
|
||||
let key = {
|
||||
let mut data = vec![];
|
||||
|
||||
data.push(inst.class.opcode as u32);
|
||||
for op in &inst.operands {
|
||||
for op in &mut inst.operands {
|
||||
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
|
||||
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
|
||||
}
|
||||
op.assemble_into(&mut data);
|
||||
}
|
||||
|
||||
data
|
||||
};
|
||||
|
||||
// dedup contains a tuple of three indices;
|
||||
// the first two point into our `def_use_analyzer.instructions` map
|
||||
// the last one points into the `module.types_global_values` iterator so we can resume iteration
|
||||
dedup
|
||||
.entry(data)
|
||||
.and_modify(|(identical_idx, backtrack_idx)| {
|
||||
duplicate = Some((inst_idx, *identical_idx, *backtrack_idx));
|
||||
})
|
||||
.or_insert((inst_idx, iterator_idx)); // store the index that we encountered an instruction
|
||||
// for the first time so we can backtrack later
|
||||
|
||||
if let Some((_, _, backtrack_idx)) = duplicate {
|
||||
continue_from_idx = backtrack_idx;
|
||||
break;
|
||||
match key_to_result_id.entry(key) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(inst.result_id.unwrap());
|
||||
}
|
||||
hash_map::Entry::Occupied(entry) => {
|
||||
assert!(!rewrite_rules.contains_key(&inst.result_id.unwrap()));
|
||||
rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
|
||||
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((before_idx, after_idx, _)) = duplicate {
|
||||
let before_id = def_use_analyzer.instructions[before_idx].result_id.unwrap();
|
||||
let after_id = def_use_analyzer.instructions[after_idx].result_id.unwrap();
|
||||
|
||||
// remove annotations later
|
||||
kill_annotations.push(before_id);
|
||||
|
||||
def_use_analyzer.for_each_use(before_id, |inst| {
|
||||
if inst.result_type == Some(before_id) {
|
||||
inst.result_type = Some(after_id);
|
||||
}
|
||||
|
||||
for op in inst.operands.iter_mut() {
|
||||
match op {
|
||||
rspirv::dr::Operand::IdMemorySemantics(w)
|
||||
| rspirv::dr::Operand::IdScope(w)
|
||||
| rspirv::dr::Operand::IdRef(w) => {
|
||||
if *w == before_id {
|
||||
*w = after_id
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// this loop / system works on the assumption that all indices remain valid,
|
||||
// so instead of removing the instruction we just nop it out - `consume_instruction` will then
|
||||
// skip all OpNops and they won't appear in the newly constructed module
|
||||
def_use_analyzer.instructions[before_idx] =
|
||||
rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut loader = rspirv::dr::Loader::new();
|
||||
|
||||
for inst in def_use_analyzer.instructions.iter() {
|
||||
loader.consume_instruction(inst.clone());
|
||||
}
|
||||
|
||||
let mut module = loader.module();
|
||||
|
||||
for remove in kill_annotations {
|
||||
kill_annotations_and_debug(&mut module, remove);
|
||||
}
|
||||
|
||||
module
|
||||
.types_global_values
|
||||
.retain(|op| op.class.opcode != spirv::Op::Nop);
|
||||
|
||||
for inst in module.all_inst_iter_mut() {
|
||||
if let Some(ref mut id) = inst.result_type {
|
||||
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
|
||||
}
|
||||
for op in &mut inst.operands {
|
||||
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
|
||||
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@ -469,91 +403,6 @@ impl DefAnalyzer {
|
||||
}
|
||||
}
|
||||
|
||||
struct DefUseAnalyzer<'a> {
|
||||
def_ids: HashMap<u32, usize>,
|
||||
use_ids: HashMap<u32, Vec<usize>>,
|
||||
use_result_type_ids: HashMap<u32, Vec<usize>>,
|
||||
instructions: &'a mut [rspirv::dr::Instruction],
|
||||
}
|
||||
|
||||
impl<'a> DefUseAnalyzer<'a> {
|
||||
fn new(instructions: &'a mut [rspirv::dr::Instruction]) -> Self {
|
||||
let mut def_ids = HashMap::new();
|
||||
let mut use_ids: HashMap<u32, Vec<usize>> = HashMap::new();
|
||||
let mut use_result_type_ids: HashMap<u32, Vec<usize>> = HashMap::new();
|
||||
|
||||
instructions
|
||||
.iter()
|
||||
.enumerate()
|
||||
.for_each(|(inst_idx, inst)| {
|
||||
if let Some(def_id) = inst.result_id {
|
||||
def_ids
|
||||
.entry(def_id)
|
||||
.and_modify(|stored_inst| {
|
||||
*stored_inst = inst_idx;
|
||||
})
|
||||
.or_insert(inst_idx);
|
||||
}
|
||||
|
||||
if let Some(result_type) = inst.result_type {
|
||||
use_result_type_ids
|
||||
.entry(result_type)
|
||||
.and_modify(|v| v.push(inst_idx))
|
||||
.or_insert_with(|| vec![inst_idx]);
|
||||
}
|
||||
|
||||
for op in inst.operands.iter() {
|
||||
match op {
|
||||
rspirv::dr::Operand::IdMemorySemantics(w)
|
||||
| rspirv::dr::Operand::IdScope(w)
|
||||
| rspirv::dr::Operand::IdRef(w) => {
|
||||
use_ids
|
||||
.entry(*w)
|
||||
.and_modify(|v| v.push(inst_idx))
|
||||
.or_insert_with(|| vec![inst_idx]);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
def_ids,
|
||||
use_ids,
|
||||
use_result_type_ids,
|
||||
instructions,
|
||||
}
|
||||
}
|
||||
|
||||
fn def_idx(&self, id: u32) -> usize {
|
||||
self.def_ids[&id]
|
||||
}
|
||||
|
||||
fn def(&self, id: u32) -> (usize, &rspirv::dr::Instruction) {
|
||||
let idx = self.def_idx(id);
|
||||
(idx, &self.instructions[idx])
|
||||
}
|
||||
|
||||
fn for_each_use<F>(&mut self, id: u32, mut f: F)
|
||||
where
|
||||
F: FnMut(&mut rspirv::dr::Instruction),
|
||||
{
|
||||
// find by `result_type`
|
||||
if let Some(use_result_type_id) = self.use_result_type_ids.get(&id) {
|
||||
for inst_idx in use_result_type_id {
|
||||
f(&mut self.instructions[*inst_idx])
|
||||
}
|
||||
}
|
||||
|
||||
// find by operand
|
||||
if let Some(use_id) = self.use_ids.get(&id) {
|
||||
for inst_idx in use_id {
|
||||
f(&mut self.instructions[*inst_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn import_kill_annotations_and_debug(module: &mut rspirv::dr::Module, info: &LinkInfo) {
|
||||
for import in &info.imports {
|
||||
kill_annotations_and_debug(module, import.id);
|
||||
@ -1210,7 +1059,7 @@ pub fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result<rs
|
||||
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
|
||||
remove_duplicate_capablities(&mut output);
|
||||
remove_duplicate_ext_inst_imports(&mut output);
|
||||
let mut output = remove_duplicate_types(output);
|
||||
remove_duplicate_types(&mut output);
|
||||
// jb-todo: strip identical OpDecoration / OpDecorationGroups
|
||||
|
||||
// remove names and decorations of import variables / functions https://github.com/KhronosGroup/SPIRV-Tools/blob/8a0ebd40f86d1f18ad42ea96c6ac53915076c3c7/source/opt/ir_context.cpp#L404
|
||||
|
@ -1,4 +1,7 @@
|
||||
use crate::link;
|
||||
use crate::LinkerError;
|
||||
use crate::Options;
|
||||
use crate::Result;
|
||||
|
||||
// https://github.com/colin-kiegel/rust-pretty-assertions/issues/24
|
||||
#[derive(PartialEq, Eq)]
|
||||
@ -21,7 +24,7 @@ fn assemble_spirv(spirv: &str) -> Vec<u8> {
|
||||
|
||||
std::fs::write(&input, spirv).unwrap();
|
||||
|
||||
let process = Command::new("spirv-as.exe")
|
||||
let process = Command::new("spirv-as")
|
||||
.arg(input.to_str().unwrap())
|
||||
.arg("-o")
|
||||
.arg(output.to_str().unwrap())
|
||||
@ -64,8 +67,7 @@ fn validate(spirv: &[u32]) {
|
||||
fn load(bytes: &[u8]) -> rspirv::dr::Module {
|
||||
let mut loader = rspirv::dr::Loader::new();
|
||||
rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
|
||||
let module = loader.module();
|
||||
module
|
||||
loader.module()
|
||||
}
|
||||
|
||||
fn assemble_and_link(
|
||||
@ -73,7 +75,7 @@ fn assemble_and_link(
|
||||
opts: &crate::Options,
|
||||
) -> crate::Result<rspirv::dr::Module> {
|
||||
let mut modules = binaries.iter().cloned().map(load).collect::<Vec<_>>();
|
||||
let mut modules = modules.iter_mut().map(|m| m).collect::<Vec<_>>();
|
||||
let mut modules = modules.iter_mut().collect::<Vec<_>>();
|
||||
|
||||
link(&mut modules, opts)
|
||||
}
|
||||
@ -88,13 +90,13 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
|
||||
let result = result.disassemble();
|
||||
|
||||
let expected = expected
|
||||
.split("\n")
|
||||
.split('\n')
|
||||
.map(|l| l.trim())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let result = result
|
||||
.split("\n")
|
||||
.split('\n')
|
||||
.map(|l| l.trim().replace(" ", " ")) // rspirv outputs multiple spaces between operands
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
@ -111,14 +113,6 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
|
||||
}
|
||||
}
|
||||
|
||||
mod test {
|
||||
use crate::test::assemble_and_link;
|
||||
use crate::test::assemble_spirv;
|
||||
use crate::test::without_header_eq;
|
||||
use crate::LinkerError;
|
||||
use crate::Options;
|
||||
use crate::Result;
|
||||
|
||||
#[test]
|
||||
fn standard() -> Result<()> {
|
||||
let a = assemble_spirv(
|
||||
@ -236,8 +230,8 @@ mod test {
|
||||
result.err(),
|
||||
Some(LinkerError::TypeMismatch {
|
||||
name: "foo".to_string(),
|
||||
import_type: "OpTypeFloat 32".to_string(),
|
||||
export_type: "OpTypeInt 32 0".to_string(),
|
||||
import_type: "f32".to_string(),
|
||||
export_type: "u32".to_string(),
|
||||
})
|
||||
);
|
||||
Ok(())
|
||||
@ -513,4 +507,3 @@ mod test {
|
||||
without_header_eq(result, expect);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user