Simplify remove_duplicate_types

This commit is contained in:
khyperia 2020-09-22 17:26:17 +02:00
parent 31e8d95898
commit a3f859afa5
2 changed files with 214 additions and 372 deletions

View File

@ -163,119 +163,53 @@ fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
kill_with_id(&mut module.debugs, id);
}
fn remove_duplicate_types(module: rspirv::dr::Module) -> rspirv::dr::Module {
fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
use rspirv::binary::Assemble;
// jb-todo: spirv-tools's linker has special case handling for SpvOpTypeForwardPointer,
// not sure if we need that; see https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp for reference
let mut rewrite_rules = HashMap::new();
let mut key_to_result_id = HashMap::new();
let mut instructions = module
.all_inst_iter()
.cloned()
.collect::<Vec<_>>()
.into_boxed_slice(); // force boxed slice so we don't accidentally grow or shrink it later
for inst in &mut module.types_global_values {
let key = {
let mut data = vec![];
let mut def_use_analyzer = DefUseAnalyzer::new(&mut instructions);
let mut kill_annotations = vec![];
let mut continue_from_idx = 0;
// need to do this process iteratively because types can reference each other
loop {
let mut dedup = std::collections::HashMap::new();
let mut duplicate = None;
for (iterator_idx, module_inst) in module
.types_global_values
.iter()
.enumerate()
.skip(continue_from_idx)
{
let (inst_idx, inst) = def_use_analyzer.def(module_inst.result_id.unwrap());
if inst.class.opcode == spirv::Op::Nop {
continue;
data.push(inst.class.opcode as u32);
for op in &mut inst.operands {
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
op.assemble_into(&mut data);
}
// partially assemble only the opcode and operands to be used as a key
// maybe this should also include the result_type
let data = {
let mut data = vec![];
data
};
data.push(inst.class.opcode as u32);
for op in &inst.operands {
op.assemble_into(&mut data);
}
data
};
// dedup contains a tuple of three indices;
// the first two point into our `def_use_analyzer.instructions` map
// the last one points into the `module.types_global_values` iterator so we can resume iteration
dedup
.entry(data)
.and_modify(|(identical_idx, backtrack_idx)| {
duplicate = Some((inst_idx, *identical_idx, *backtrack_idx));
})
.or_insert((inst_idx, iterator_idx)); // store the index that we encountered an instruction
// for the first time so we can backtrack later
if let Some((_, _, backtrack_idx)) = duplicate {
continue_from_idx = backtrack_idx;
break;
match key_to_result_id.entry(key) {
hash_map::Entry::Vacant(entry) => {
entry.insert(inst.result_id.unwrap());
}
hash_map::Entry::Occupied(entry) => {
assert!(!rewrite_rules.contains_key(&inst.result_id.unwrap()));
rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
}
}
if let Some((before_idx, after_idx, _)) = duplicate {
let before_id = def_use_analyzer.instructions[before_idx].result_id.unwrap();
let after_id = def_use_analyzer.instructions[after_idx].result_id.unwrap();
// remove annotations later
kill_annotations.push(before_id);
def_use_analyzer.for_each_use(before_id, |inst| {
if inst.result_type == Some(before_id) {
inst.result_type = Some(after_id);
}
for op in inst.operands.iter_mut() {
match op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => {
if *w == before_id {
*w = after_id
}
}
_ => {}
}
}
});
// this loop / system works on the assumption that all indices remain valid,
// so instead of removing the instruction we just nop it out - `consume_instruction` will then
// skip all OpNops and they won't appear in the newly constructed module
def_use_analyzer.instructions[before_idx] =
rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
} else {
break;
}
}
let mut loader = rspirv::dr::Loader::new();
for inst in def_use_analyzer.instructions.iter() {
loader.consume_instruction(inst.clone());
}
let mut module = loader.module();
for remove in kill_annotations {
kill_annotations_and_debug(&mut module, remove);
}
module
.types_global_values
.retain(|op| op.class.opcode != spirv::Op::Nop);
for inst in module.all_inst_iter_mut() {
if let Some(ref mut id) = inst.result_type {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
for op in &mut inst.operands {
if let rspirv::dr::Operand::IdRef(ref mut id) = op {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
}
}
}
#[derive(Clone, Debug)]
@ -469,91 +403,6 @@ impl DefAnalyzer {
}
}
struct DefUseAnalyzer<'a> {
def_ids: HashMap<u32, usize>,
use_ids: HashMap<u32, Vec<usize>>,
use_result_type_ids: HashMap<u32, Vec<usize>>,
instructions: &'a mut [rspirv::dr::Instruction],
}
impl<'a> DefUseAnalyzer<'a> {
fn new(instructions: &'a mut [rspirv::dr::Instruction]) -> Self {
let mut def_ids = HashMap::new();
let mut use_ids: HashMap<u32, Vec<usize>> = HashMap::new();
let mut use_result_type_ids: HashMap<u32, Vec<usize>> = HashMap::new();
instructions
.iter()
.enumerate()
.for_each(|(inst_idx, inst)| {
if let Some(def_id) = inst.result_id {
def_ids
.entry(def_id)
.and_modify(|stored_inst| {
*stored_inst = inst_idx;
})
.or_insert(inst_idx);
}
if let Some(result_type) = inst.result_type {
use_result_type_ids
.entry(result_type)
.and_modify(|v| v.push(inst_idx))
.or_insert_with(|| vec![inst_idx]);
}
for op in inst.operands.iter() {
match op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => {
use_ids
.entry(*w)
.and_modify(|v| v.push(inst_idx))
.or_insert_with(|| vec![inst_idx]);
}
_ => {}
}
}
});
Self {
def_ids,
use_ids,
use_result_type_ids,
instructions,
}
}
fn def_idx(&self, id: u32) -> usize {
self.def_ids[&id]
}
fn def(&self, id: u32) -> (usize, &rspirv::dr::Instruction) {
let idx = self.def_idx(id);
(idx, &self.instructions[idx])
}
fn for_each_use<F>(&mut self, id: u32, mut f: F)
where
F: FnMut(&mut rspirv::dr::Instruction),
{
// find by `result_type`
if let Some(use_result_type_id) = self.use_result_type_ids.get(&id) {
for inst_idx in use_result_type_id {
f(&mut self.instructions[*inst_idx])
}
}
// find by operand
if let Some(use_id) = self.use_ids.get(&id) {
for inst_idx in use_id {
f(&mut self.instructions[*inst_idx]);
}
}
}
}
fn import_kill_annotations_and_debug(module: &mut rspirv::dr::Module, info: &LinkInfo) {
for import in &info.imports {
kill_annotations_and_debug(module, import.id);
@ -1210,7 +1059,7 @@ pub fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result<rs
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
remove_duplicate_capablities(&mut output);
remove_duplicate_ext_inst_imports(&mut output);
let mut output = remove_duplicate_types(output);
remove_duplicate_types(&mut output);
// jb-todo: strip identical OpDecoration / OpDecorationGroups
// remove names and decorations of import variables / functions https://github.com/KhronosGroup/SPIRV-Tools/blob/8a0ebd40f86d1f18ad42ea96c6ac53915076c3c7/source/opt/ir_context.cpp#L404

View File

@ -1,4 +1,7 @@
use crate::link;
use crate::LinkerError;
use crate::Options;
use crate::Result;
// https://github.com/colin-kiegel/rust-pretty-assertions/issues/24
#[derive(PartialEq, Eq)]
@ -21,7 +24,7 @@ fn assemble_spirv(spirv: &str) -> Vec<u8> {
std::fs::write(&input, spirv).unwrap();
let process = Command::new("spirv-as.exe")
let process = Command::new("spirv-as")
.arg(input.to_str().unwrap())
.arg("-o")
.arg(output.to_str().unwrap())
@ -64,8 +67,7 @@ fn validate(spirv: &[u32]) {
fn load(bytes: &[u8]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new();
rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
let module = loader.module();
module
loader.module()
}
fn assemble_and_link(
@ -73,7 +75,7 @@ fn assemble_and_link(
opts: &crate::Options,
) -> crate::Result<rspirv::dr::Module> {
let mut modules = binaries.iter().cloned().map(load).collect::<Vec<_>>();
let mut modules = modules.iter_mut().map(|m| m).collect::<Vec<_>>();
let mut modules = modules.iter_mut().collect::<Vec<_>>();
link(&mut modules, opts)
}
@ -88,13 +90,13 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
let result = result.disassemble();
let expected = expected
.split("\n")
.split('\n')
.map(|l| l.trim())
.collect::<Vec<_>>()
.join("\n");
let result = result
.split("\n")
.split('\n')
.map(|l| l.trim().replace(" ", " ")) // rspirv outputs multiple spaces between operands
.collect::<Vec<_>>()
.join("\n");
@ -111,240 +113,232 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
}
}
mod test {
use crate::test::assemble_and_link;
use crate::test::assemble_spirv;
use crate::test::without_header_eq;
use crate::LinkerError;
use crate::Options;
use crate::Result;
#[test]
fn standard() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn standard() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
);
let b = assemble_spirv(
r#"OpCapability Linkage
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3
"#,
);
);
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
%1 = OpTypeFloat 32
%2 = OpVariable %1 Input
%3 = OpConstant %1 42.0
%4 = OpVariable %1 Uniform %3"#;
without_header_eq(result, expect);
Ok(())
}
without_header_eq(result, expect);
Ok(())
}
#[test]
fn not_a_lib_extra_exports() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn not_a_lib_extra_exports() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform"#,
);
);
let result = assemble_and_link(&[&a], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
let result = assemble_and_link(&[&a], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
%1 = OpTypeFloat 32
%2 = OpVariable %1 Uniform"#;
without_header_eq(result, expect);
Ok(())
}
without_header_eq(result, expect);
Ok(())
}
#[test]
fn lib_extra_exports() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn lib_extra_exports() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform"#,
);
);
let result = assemble_and_link(
&[&a],
&Options {
lib: true,
..Default::default()
},
)?;
let result = assemble_and_link(
&[&a],
&Options {
lib: true,
..Default::default()
},
)?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform"#;
without_header_eq(result, expect);
Ok(())
}
without_header_eq(result, expect);
Ok(())
}
#[test]
fn unresolved_symbol() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn unresolved_symbol() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform"#,
);
);
let b = assemble_spirv("OpCapability Linkage");
let b = assemble_spirv("OpCapability Linkage");
let result = assemble_and_link(&[&a, &b], &Options::default());
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::UnresolvedSymbol("foo".to_string()))
);
assert_eq!(
result.err(),
Some(LinkerError::UnresolvedSymbol("foo".to_string()))
);
Ok(())
}
Ok(())
}
#[test]
fn type_mismatch() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn type_mismatch() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
);
let b = assemble_spirv(
r#"OpCapability Linkage
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeInt 32 0
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3
"#,
);
);
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::TypeMismatch {
name: "foo".to_string(),
import_type: "OpTypeFloat 32".to_string(),
export_type: "OpTypeInt 32 0".to_string(),
})
);
Ok(())
}
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::TypeMismatch {
name: "foo".to_string(),
import_type: "f32".to_string(),
export_type: "u32".to_string(),
})
);
Ok(())
}
#[test]
fn multiple_definitions() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn multiple_definitions() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
);
let b = assemble_spirv(
r#"OpCapability Linkage
let b = assemble_spirv(
r#"OpCapability Linkage
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
);
let c = assemble_spirv(
r#"OpCapability Linkage
let c = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 -1
%1 = OpVariable %2 Uniform %3"#,
);
);
let result = assemble_and_link(&[&a, &b, &c], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}
let result = assemble_and_link(&[&a, &b, &c], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}
#[test]
fn multiple_definitions_different_types() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn multiple_definitions_different_types() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
);
let b = assemble_spirv(
r#"OpCapability Linkage
let b = assemble_spirv(
r#"OpCapability Linkage
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeInt 32 0
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
);
let c = assemble_spirv(
r#"OpCapability Linkage
let c = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 12
%1 = OpVariable %2 Uniform %3"#,
);
);
let result = assemble_and_link(&[&a, &b, &c], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}
let result = assemble_and_link(&[&a, &b, &c], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}
//jb-todo: this isn't validated yet in the linker (see ensure_matching_import_export_pairs)
/*#[test]
fn decoration_mismatch() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
OpDecorate %2 Constant
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
//jb-todo: this isn't validated yet in the linker (see ensure_matching_import_export_pairs)
/*#[test]
fn decoration_mismatch() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
OpDecorate %2 Constant
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}*/
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}*/
#[test]
fn func_ctrl() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
#[test]
fn func_ctrl() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeVoid
%3 = OpTypeFunction %2
@ -352,10 +346,10 @@ mod test {
%5 = OpVariable %4 Uniform
%1 = OpFunction %2 None %3
OpFunctionEnd"#,
);
);
let b = assemble_spirv(
r#"OpCapability Linkage
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeVoid
%3 = OpTypeFunction %2
@ -363,11 +357,11 @@ mod test {
%4 = OpLabel
OpReturn
OpFunctionEnd"#,
);
);
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
%1 = OpTypeVoid
%2 = OpTypeFloat 32
%3 = OpTypeFunction %1
@ -377,14 +371,14 @@ mod test {
OpReturn
OpFunctionEnd"#;
without_header_eq(result, expect);
Ok(())
}
without_header_eq(result, expect);
Ok(())
}
#[test]
fn use_exported_func_param_attr() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Kernel
#[test]
fn use_exported_func_param_attr() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Kernel
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
OpDecorate %2 FuncParamAttr Zext
@ -400,10 +394,10 @@ mod test {
%4 = OpFunctionParameter %6
OpFunctionEnd
"#,
);
);
let b = assemble_spirv(
r#"OpCapability Kernel
let b = assemble_spirv(
r#"OpCapability Kernel
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
OpDecorate %2 FuncParamAttr Sext
@ -416,11 +410,11 @@ mod test {
OpReturn
OpFunctionEnd
"#,
);
);
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpCapability Kernel
let expect = r#"OpCapability Kernel
OpModuleProcessed "Linked by rspirv-linker"
OpDecorate %1 FuncParamAttr Sext
OpDecorate %2 FuncParamAttr Zext
@ -438,14 +432,14 @@ mod test {
OpReturn
OpFunctionEnd"#;
without_header_eq(result, expect);
Ok(())
}
without_header_eq(result, expect);
Ok(())
}
#[test]
fn names_and_decorations() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Kernel
#[test]
fn names_and_decorations() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Kernel
OpCapability Linkage
OpName %1 "foo"
OpName %3 "param"
@ -465,10 +459,10 @@ mod test {
%4 = OpFunctionParameter %9
OpFunctionEnd
"#,
);
);
let b = assemble_spirv(
r#"OpCapability Kernel
let b = assemble_spirv(
r#"OpCapability Kernel
OpCapability Linkage
OpName %1 "foo"
OpName %2 "param"
@ -484,11 +478,11 @@ mod test {
OpReturn
OpFunctionEnd
"#,
);
);
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpCapability Kernel
let expect = r#"OpCapability Kernel
OpName %1 "param"
OpName %2 "foo"
OpModuleProcessed "Linked by rspirv-linker"
@ -510,7 +504,6 @@ mod test {
OpReturn
OpFunctionEnd"#;
without_header_eq(result, expect);
Ok(())
}
without_header_eq(result, expect);
Ok(())
}