Add a bunch of test cases from spirv-linker and bug fixes to make these test cases work

This commit is contained in:
Jasper Bekkers 2020-09-03 17:58:34 +01:00
parent 733008a993
commit cc77d911d3
No known key found for this signature in database
GPG Key ID: C59CE25F4DA6625D
4 changed files with 402 additions and 42 deletions

View File

@ -273,6 +273,7 @@ dependencies = [
"pretty_assertions",
"rspirv",
"tempfile",
"thiserror",
"topological-sort",
]
@ -344,6 +345,26 @@ dependencies = [
"winapi",
]
[[package]]
name = "thiserror"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dfdd070ccd8ccb78f4ad66bf1982dc37f620ef696c6b5028fe2ed83dd3d0d08"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd80fc12f73063ac132ac92aceea36734f04a1d93c1240c6944e23a3b8841793"
dependencies = [
"proc-macro2 1.0.19",
"quote 1.0.7",
"syn 1.0.39",
]
[[package]]
name = "thread_local"
version = "1.0.1"

View File

@ -9,6 +9,7 @@ edition = "2018"
[dependencies]
rspirv = { path = "C:/Users/Jasper/traverse/rspirv/rspirv/"}
topological-sort = "0.1"
thiserror = "1.0.20"
[dev-dependencies]
tempfile = "3.1"

View File

@ -4,8 +4,27 @@ use rspirv::binary::Consumer;
use rspirv::binary::Disassemble;
use rspirv::spirv;
use std::collections::{HashMap, HashSet};
use thiserror::Error;
use topological_sort::TopologicalSort;
#[derive(Error, Debug, PartialEq)]
pub enum LinkerError {
#[error("Unresolved symbol {:?}", .0)]
UnresolvedSymbol(String),
#[error("Multiple exports found for {:?}", .0)]
MultipleExports(String),
#[error("Types mismatch for {:?}, imported with type {:?}, exported with type {:?}", .name, .import_type, .export_type)]
TypeMismatch {
name: String,
import_type: String,
export_type: String,
},
#[error("unknown data store error")]
Unknown,
}
type Result<T> = std::result::Result<T, LinkerError>;
fn load(bytes: &[u8]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new();
rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
@ -90,15 +109,21 @@ fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
}
fn kill_with_id(insts: &mut Vec<rspirv::dr::Instruction>, id: u32) {
kill_with(insts, |inst| match inst.operands[0] {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w)
if w == id =>
{
true
kill_with(insts, |inst| {
if inst.operands.is_empty() {
return false;
}
match inst.operands[0] {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w)
if w == id =>
{
true
}
_ => false,
}
_ => false,
})
}
@ -117,7 +142,7 @@ where
insts.swap_remove(idx);
}
if idx == 0 {
if idx == 0 || insts.is_empty() {
break;
}
@ -133,14 +158,13 @@ fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
// jb-todo: spirv-tools's linker has special case handling for SpvOpTypeForwardPointer,
// not sure if we need that; see https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp for reference
let mut start = 0;
// need to do this process iteratively because types can reference each other
loop {
let mut replace = None;
// start with `nth` so we can restart this loop quickly after killing the op
for (i_idx, i) in module.types_global_values.iter().enumerate().nth(start) {
for (i_idx, i) in module.types_global_values.iter().enumerate() {
let mut identical = None;
for j in module.types_global_values.iter().skip(i_idx + 1) {
if i.is_type_identical(j) {
@ -161,7 +185,6 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
kill_annotations_and_debug(module, remove);
replace_all_uses_with(module, remove, keep);
module.types_global_values.swap_remove(kill_idx);
start = kill_idx; // jb-todo: is it correct to restart this loop here?
} else {
break;
}
@ -205,7 +228,7 @@ fn inst_fully_eq(a: &rspirv::dr::Instruction, b: &rspirv::dr::Instruction) -> bo
&& a.operands == b.operands
}
fn find_import_export_pairs(module: &rspirv::dr::Module, defs: &DefAnalyzer) -> LinkInfo {
fn find_import_export_pairs(module: &rspirv::dr::Module, defs: &DefAnalyzer) -> Result<LinkInfo> {
let mut imports = vec![];
let mut exports: HashMap<String, Vec<LinkSymbol>> = HashMap::new();
@ -218,7 +241,12 @@ fn find_import_export_pairs(module: &rspirv::dr::Module, defs: &DefAnalyzer) ->
rspirv::dr::Operand::IdRef(i) => i,
_ => panic!("Expected IdRef"),
};
let name = &annotation.operands[2];
let name = match &annotation.operands[2] {
rspirv::dr::Operand::LiteralString(s) => s,
_ => panic!("Expected LiteralString"),
};
let ty = &annotation.operands[3];
let def_inst = defs
@ -271,26 +299,51 @@ fn find_import_export_pairs(module: &rspirv::dr::Module, defs: &DefAnalyzer) ->
.find_potential_pairs()
}
fn cleanup_type(mut ty: rspirv::dr::Instruction) -> String {
ty.result_id = None;
ty.disassemble()
}
impl LinkInfo {
fn find_potential_pairs(mut self) -> Self {
fn find_potential_pairs(mut self) -> Result<Self> {
for import in &self.imports {
let potential_matching_exports = self.exports.get(&import.name);
if let Some(potential_matching_exports) = potential_matching_exports {
if potential_matching_exports.len() > 1 {
return Err(LinkerError::MultipleExports(import.name.clone()));
}
self.potential_pairs.push(ImportExportPair {
import: import.clone(),
export: potential_matching_exports.first().unwrap().clone(),
});
} else {
panic!("Can't find matching export for {}", import.name);
return Err(LinkerError::UnresolvedSymbol(import.name.clone()));
}
}
self
Ok(self)
}
/// returns the list of matching import / export pairs after validation the list of potential pairs
fn ensure_matching_import_export_pairs(&self) -> &Vec<ImportExportPair> {
fn ensure_matching_import_export_pairs(
&self,
defs: &DefAnalyzer,
) -> Result<&Vec<ImportExportPair>> {
for pair in &self.potential_pairs {
let import_result_type = defs.def(pair.import.type_id).unwrap();
let export_result_type = defs.def(pair.export.type_id).unwrap();
if import_result_type.class.opcode != spirv::Op::TypeFunction {
if !import_result_type.is_type_identical(export_result_type) {
return Err(LinkerError::TypeMismatch {
name: pair.import.name.clone(),
import_type: cleanup_type(import_result_type.clone()),
export_type: cleanup_type(export_result_type.clone()),
});
}
}
for (import_param, export_param) in pair
.import
.parameters
@ -305,7 +358,7 @@ impl LinkInfo {
}
}
&self.potential_pairs
Ok(&self.potential_pairs)
}
}
@ -386,6 +439,10 @@ fn kill_linkage_instructions(
let eq = pairs
.iter()
.find(|p| {
if inst.operands.is_empty() {
return false;
}
if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
id == p.import.id || id == p.export.id
} else {
@ -454,19 +511,19 @@ fn sort_globals(module: &mut rspirv::dr::Module) {
let mut ts = TopologicalSort::<u32>::new();
for t in module.types_global_values.iter() {
if let Some(result_type) = t.result_type {
if let Some(result_id) = t.result_id {
if let Some(result_id) = t.result_id {
if let Some(result_type) = t.result_type {
ts.add_dependency(result_type, result_id);
}
for op in &t.operands {
match op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => {
ts.add_dependency(*w, result_id); // the op defining the IdRef should come before our op / result_id
}
_ => {}
for op in &t.operands {
match op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => {
ts.add_dependency(*w, result_id); // the op defining the IdRef should come before our op / result_id
}
_ => {}
}
}
}
@ -494,7 +551,7 @@ fn sort_globals(module: &mut rspirv::dr::Module) {
module.types_global_values = new_types_global_values;
}
fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> rspirv::dr::Module {
fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result<rspirv::dr::Module> {
// shift all the ids
let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
@ -520,10 +577,10 @@ fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> rspirv::dr::M
// find import / export pairs
let defs = DefAnalyzer::new(&output);
let info = find_import_export_pairs(&output, &defs);
let info = find_import_export_pairs(&output, &defs)?;
// ensure import / export pairs have matching types and defintions
let matching_pairs = info.ensure_matching_import_export_pairs();
let matching_pairs = info.ensure_matching_import_export_pairs(&defs)?;
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
remove_duplicates(&mut output);
@ -557,10 +614,10 @@ fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> rspirv::dr::M
println!("{}\n\n", output.disassemble());
// output the module
output
Ok(output)
}
fn main() {
fn main() -> Result<()> {
let body1 = include_bytes!("../test/1/body_1.spv");
let body2 = include_bytes!("../test/1/body_2.spv");
@ -572,6 +629,8 @@ fn main() {
partial: false,
};
let output = link(&mut [&mut body1, &mut body2], &opts);
let output = link(&mut [&mut body1, &mut body2], &opts)?;
println!("{}\n\n", output.disassemble());
Ok(())
}

View File

@ -68,7 +68,10 @@ fn load(bytes: &[u8]) -> rspirv::dr::Module {
module
}
fn assemble_and_link(binaries: &[&[u8]], opts: &crate::Options) -> rspirv::dr::Module {
fn assemble_and_link(
binaries: &[&[u8]],
opts: &crate::Options,
) -> crate::Result<rspirv::dr::Module> {
let mut modules = binaries.iter().cloned().map(load).collect::<Vec<_>>();
let mut modules = modules.iter_mut().map(|m| m).collect::<Vec<_>>();
@ -97,6 +100,7 @@ fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
.join("\n");
if result != expected {
println!("{}", &result);
panic!(
"assertion failed: `(left.contains(right))`\
\n\
@ -111,10 +115,12 @@ mod test {
use crate::test::assemble_and_link;
use crate::test::assemble_spirv;
use crate::test::without_header_eq;
use crate::LinkerError;
use crate::Options;
use crate::Result;
#[test]
fn standard() {
fn standard() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
@ -132,7 +138,7 @@ mod test {
"#,
);
let result = assemble_and_link(&[&a, &b], &Options::default());
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
%1 = OpTypeFloat 32
%2 = OpVariable %1 Input
@ -140,10 +146,11 @@ mod test {
%4 = OpVariable %1 Uniform %3"#;
without_header_eq(result, expect);
Ok(())
}
#[test]
fn not_a_lib_extra_exports() {
fn not_a_lib_extra_exports() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
@ -151,15 +158,16 @@ mod test {
%1 = OpVariable %2 Uniform"#,
);
let result = assemble_and_link(&[&a], &Options::default());
let result = assemble_and_link(&[&a], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
%1 = OpTypeFloat 32
%2 = OpVariable %1 Uniform"#;
without_header_eq(result, expect);
Ok(())
}
#[test]
fn lib_extra_exports() {
fn lib_extra_exports() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
@ -173,12 +181,283 @@ mod test {
lib: true,
..Default::default()
},
);
)?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform"#;
without_header_eq(result, expect);
Ok(())
}
#[test]
fn unresolved_symbol() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform"#,
);
let b = assemble_spirv("OpCapability Linkage");
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::UnresolvedSymbol("foo".to_string()))
);
Ok(())
}
#[test]
fn type_mismatch() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeInt 32 0
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3
"#,
);
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::TypeMismatch {
name: "foo".to_string(),
import_type: "OpTypeFloat 32".to_string(),
export_type: "OpTypeInt 32 0".to_string(),
})
);
Ok(())
}
#[test]
fn multiple_definitions() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
let c = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 -1
%1 = OpVariable %2 Uniform %3"#,
);
let result = assemble_and_link(&[&a, &b, &c], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}
#[test]
fn multiple_definitions_different_types() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeInt 32 0
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
let c = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 12
%1 = OpVariable %2 Uniform %3"#,
);
let result = assemble_and_link(&[&a, &b, &c], &Options::default());
assert_eq!(
result.err(),
Some(LinkerError::MultipleExports("foo".to_string()))
);
Ok(())
}
/*
jb-todo: this isn't validated yet in the linker (see ensure_matching_import_export_pairs)
#[test]
fn decoration_mismatch() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
OpDecorate %2 Constant
%2 = OpTypeFloat 32
%1 = OpVariable %2 Uniform
%3 = OpVariable %2 Input"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeFloat 32
%3 = OpConstant %2 42
%1 = OpVariable %2 Uniform %3"#,
);
let result = assemble_and_link(&[&a, &b], &Options::default());
assert_eq!(result.err(), Some(LinkerError::MultipleExports("foo".to_string())));
Ok(())
}*/
/*
jb-todo: disabled because `ensure_matching_import_export_pairs` is broken - it should recursively check type instead of doing a simple `is_type_identical`
#[test]
fn func_ctrl() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpVariable %4 Uniform
%1 = OpFunction %2 None %3
OpFunctionEnd"#,
);
let b = assemble_spirv(
r#"OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%1 = OpFunction %2 Inline %3
%4 = OpLabel
OpReturn
OpFunctionEnd"#,
);
let result = assemble_and_link(&[&a, &b], &Options::default())?;
let expect = r#"OpModuleProcessed "Linked by rspirv-linker"
%1 = OpTypeVoid
%2 = OpTypeFunction %1
%3 = OpTypeFloat 32
%4 = OpVariable %3 Uniform
%5 = OpFunction %1 Inline %2
%6 = OpLabel
OpReturn
OpFunctionEnd"#;
without_header_eq(result, expect);
Ok(())
}*/
/*
#[test]
fn func_ctrl() -> Result<()> {
let a = assemble_spirv(
r#"OpCapability Kernel
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Import
OpDecorate %2 FuncParamAttr Zext
%2 = OpDecorationGroup
OpGroupDecorate %2 %3 %4
%5 = OpTypeVoid
%6 = OpTypeInt 32 0
%7 = OpTypeFunction %5 %6
%1 = OpFunction %5 None %7
%3 = OpFunctionParameter %6
OpFunctionEnd
%8 = OpFunction %5 None %7
%4 = OpFunctionParameter %6
OpFunctionEnd
"#,
);
let b = assemble_spirv(
r#"OpCapability Kernel
OpCapability Linkage
OpDecorate %1 LinkageAttributes "foo" Export
OpDecorate %2 FuncParamAttr Sext
%3 = OpTypeVoid
%4 = OpTypeInt 32 0
%5 = OpTypeFunction %3 %4
%1 = OpFunction %3 None %5
%2 = OpFunctionParameter %4
%6 = OpLabel
OpReturn
OpFunctionEnd
"#,
);
let result = assemble_and_link(&[&a, &b], &Options::default())?;
/*
OpCapability Kernel
OpModuleProcessed "Linked by rspirv-linker"
OpDecorate %1 FuncParamAttr Sext
OpDecorate %2 FuncParamAttr Zext
%2 = OpDecorationGroup
OpGroupDecorate %2 %3 %4
%5 = OpTypeVoid
%6 = OpTypeInt 32 0
%7 = OpTypeFunction %5 %6
%8 = OpFunction %5 None %7
%4 = OpFunctionParameter %6
OpFunctionEnd
%9 = OpFunction %5 None %7
%1 = OpFunctionParameter %6
%10 = OpLabel
OpReturn
OpFunctionEnd
*/
let expect = r#"OpCapability Kernel
OpModuleProcessed "Linked by rspirv-linker"
OpDecorate %1 FuncParamAttr Zext
OpDecorate %3 FuncParamAttr Sext
%1 = OpDecorationGroup
OpGroupDecorate %1 %2
%4 = OpTypeVoid
%5 = OpTypeInt 32 0
%6 = OpTypeFunction %4 %5
%7 = OpFunction %4 None %6
%2 = OpFunctionParameter %5
OpFunctionEnd
%8 = OpFunction %4 None %6
%3 = OpFunctionParameter %5
%9 = OpLabel
OpReturn
OpFunctionEnd
"#;
without_header_eq(result, expect);
Ok(())
}*/
}