diff --git a/rspirv-linker/src/capability_computation.rs b/rspirv-linker/src/capability_computation.rs new file mode 100644 index 0000000000..dca0325f43 --- /dev/null +++ b/rspirv-linker/src/capability_computation.rs @@ -0,0 +1,85 @@ +use rspirv::dr::{Module, Operand}; +use rspirv::spirv::{Capability, Op}; +use std::collections::HashSet; + +pub fn remove_extra_capabilities(module: &mut Module) { + remove_capabilities(module, &compute_capabilities(module)); +} + +// TODO: This is enormously unimplemented +fn compute_capabilities(module: &Module) -> HashSet { + let mut set = HashSet::new(); + for inst in module.all_inst_iter() { + set.extend(inst.class.capabilities); + match inst.class.opcode { + Op::TypeInt => match inst.operands[0] { + Operand::LiteralInt32(width) => match width { + 8 => { + set.insert(Capability::Int8); + } + 16 => { + set.insert(Capability::Int16); + } + 64 => { + set.insert(Capability::Int64); + } + _ => {} + }, + _ => panic!(), + }, + Op::TypeFloat => match inst.operands[0] { + Operand::LiteralInt32(width) => match width { + 16 => { + set.insert(Capability::Float16); + } + 64 => { + set.insert(Capability::Float64); + } + _ => {} + }, + _ => panic!(), + }, + _ => {} + } + } + // always keep these capabilities, for now + set.insert(Capability::Addresses); + set.insert(Capability::Kernel); + set.insert(Capability::Shader); + set.insert(Capability::VariablePointers); + set.insert(Capability::VulkanMemoryModel); + set +} + +fn remove_capabilities(module: &mut Module, set: &HashSet) { + module.capabilities.retain(|inst| { + inst.class.opcode != Op::Capability + || set.contains(match &inst.operands[0] { + Operand::Capability(s) => s, + _ => panic!(), + }) + }); +} + +pub fn remove_extra_extensions(module: &mut Module) { + // TODO: Make this more generalized once this gets more advanced. + let has_intel_integer_cap = module.capabilities.iter().any(|inst| { + inst.class.opcode == Op::Capability + && match inst.operands[0] { + Operand::Capability(s) => s == Capability::IntegerFunctions2INTEL, + _ => panic!(), + } + }); + if !has_intel_integer_cap { + module.extensions.retain(|inst| { + inst.class.opcode != Op::Extension + || match &inst.operands[0] { + Operand::LiteralString(s) if s == "SPV_INTEL_shader_integer_functions2" => { + false + } + Operand::LiteralString(_) => true, + _ => panic!(), + } + }) + } +} diff --git a/rspirv-linker/src/duplicates.rs b/rspirv-linker/src/duplicates.rs index 989e837092..178f260f4d 100644 --- a/rspirv-linker/src/duplicates.rs +++ b/rspirv-linker/src/duplicates.rs @@ -18,20 +18,13 @@ pub fn remove_duplicate_extensions(module: &mut Module) { pub fn remove_duplicate_capablities(module: &mut Module) { let mut set = HashSet::new(); - let mut caps = vec![]; - - for c in &module.capabilities { - let keep = match c.operands[0] { - Operand::Capability(cap) => set.insert(cap), - _ => true, - }; - - if keep { - caps.push(c.clone()); - } - } - - module.capabilities = caps; + module.capabilities.retain(|inst| { + inst.class.opcode != Op::Capability + || set.insert(match inst.operands[0] { + Operand::Capability(s) => s, + _ => panic!(), + }) + }); } pub fn remove_duplicate_ext_inst_imports(module: &mut Module) { diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 115fc3406a..573211aa80 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod test; +mod capability_computation; mod dce; mod def_analyzer; mod duplicates; @@ -13,7 +14,6 @@ use def_analyzer::DefAnalyzer; use rspirv::binary::Consumer; use rspirv::dr::{Instruction, Loader, Module, ModuleHeader, Operand}; use rspirv::spirv::{Op, Word}; -use std::env; use thiserror::Error; #[derive(Error, Debug, PartialEq)] @@ -34,6 +34,20 @@ pub enum LinkerError { pub type Result = std::result::Result; +pub struct Options { + pub dce: bool, + pub compact_ids: bool, +} + +impl Default for Options { + fn default() -> Self { + Self { + dce: true, + compact_ids: true, + } + } +} + pub fn load(bytes: &[u8]) -> Module { let mut loader = Loader::new(); rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap(); @@ -72,67 +86,83 @@ fn extract_literal_u32(op: &Operand) -> u32 { } } -pub fn link(inputs: &mut [&mut Module], timer: impl Fn(&'static str) -> T) -> Result { - let merge_timer = timer("link_merge"); - // shift all the ids - let mut bound = inputs[0].header.as_ref().unwrap().bound - 1; - let version = inputs[0].header.as_ref().unwrap().version(); +pub fn link( + inputs: &mut [&mut Module], + opts: &Options, + timer: impl Fn(&'static str) -> T, +) -> Result { + let mut output = { + let _timer = timer("link_merge"); + // shift all the ids + let mut bound = inputs[0].header.as_ref().unwrap().bound - 1; + let version = inputs[0].header.as_ref().unwrap().version(); - for mut module in inputs.iter_mut().skip(1) { - simple_passes::shift_ids(&mut module, bound); - bound += module.header.as_ref().unwrap().bound - 1; - assert_eq!(version, module.header.as_ref().unwrap().version()); - } + for mut module in inputs.iter_mut().skip(1) { + simple_passes::shift_ids(&mut module, bound); + bound += module.header.as_ref().unwrap().bound - 1; + assert_eq!(version, module.header.as_ref().unwrap().version()); + } - // merge the binaries - let mut loader = Loader::new(); + // merge the binaries + let mut loader = Loader::new(); - for module in inputs.iter() { - module.all_inst_iter().for_each(|inst| { - loader.consume_instruction(inst.clone()); - }); - } + for module in inputs.iter() { + module.all_inst_iter().for_each(|inst| { + loader.consume_instruction(inst.clone()); + }); + } - let mut output = loader.module(); - let mut header = ModuleHeader::new(bound + 1); - header.set_version(version.0, version.1); - output.header = Some(header); + let mut output = loader.module(); + let mut header = ModuleHeader::new(bound + 1); + header.set_version(version.0, version.1); + output.header = Some(header); + output + }; - drop(merge_timer); - - let find_pairs_timer = timer("link_find_pairs"); // find import / export pairs - import_export_link::run(&mut output)?; - drop(find_pairs_timer); - - let remove_duplicates_timer = timer("link_remove_duplicates"); - // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp) - duplicates::remove_duplicate_extensions(&mut output); - duplicates::remove_duplicate_capablities(&mut output); - duplicates::remove_duplicate_ext_inst_imports(&mut output); - duplicates::remove_duplicate_types(&mut output); - // jb-todo: strip identical OpDecoration / OpDecorationGroups - drop(remove_duplicates_timer); - - let remove_zombies_timer = timer("link_remove_zombies"); - zombies::remove_zombies(&mut output); - drop(remove_zombies_timer); - - let block_ordering_pass_timer = timer("link_block_ordering_pass"); - for func in &mut output.functions { - simple_passes::block_ordering_pass(func); + { + let _timer = timer("link_find_pairs"); + import_export_link::run(&mut output)?; } - drop(block_ordering_pass_timer); - let sort_globals_timer = timer("link_sort_globals"); - simple_passes::sort_globals(&mut output); - drop(sort_globals_timer); - if env::var("DCE").is_ok() { + // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp) + { + let _timer = timer("link_remove_duplicates"); + duplicates::remove_duplicate_extensions(&mut output); + duplicates::remove_duplicate_capablities(&mut output); + duplicates::remove_duplicate_ext_inst_imports(&mut output); + duplicates::remove_duplicate_types(&mut output); + // jb-todo: strip identical OpDecoration / OpDecorationGroups + } + + { + let _timer = timer("link_remove_zombies"); + zombies::remove_zombies(&mut output); + } + + { + let _timer = timer("link_block_ordering_pass"); + for func in &mut output.functions { + simple_passes::block_ordering_pass(func); + } + } + { + let _timer = timer("link_sort_globals"); + simple_passes::sort_globals(&mut output); + } + + if opts.dce { let _timer = timer("link_dce"); dce::dce(&mut output); } - if env::var("NO_COMPACT_IDS").is_err() { + { + let _timer = timer("link_remove_extra_capabilities"); + capability_computation::remove_extra_capabilities(&mut output); + capability_computation::remove_extra_extensions(&mut output); + } + + if opts.compact_ids { let _timer = timer("link_compact_ids"); // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43 output.header.as_mut().unwrap().bound = simple_passes::compact_ids(&mut output); diff --git a/rspirv-linker/src/main.rs b/rspirv-linker/src/main.rs index c0e382a853..d1a56d568a 100644 --- a/rspirv-linker/src/main.rs +++ b/rspirv-linker/src/main.rs @@ -8,7 +8,7 @@ fn main() -> Result<()> { let mut body1 = crate::load(&body1[..]); let mut body2 = crate::load(&body2[..]); - let output = link(&mut [&mut body1, &mut body2], drop)?; + let output = link(&mut [&mut body1, &mut body2], &Options::default(), drop)?; println!("{}\n\n", output.disassemble()); Ok(()) diff --git a/rspirv-linker/src/test.rs b/rspirv-linker/src/test.rs index 218e9b9400..7e0fd05747 100644 --- a/rspirv-linker/src/test.rs +++ b/rspirv-linker/src/test.rs @@ -1,6 +1,4 @@ -use crate::link; -use crate::LinkerError; -use crate::Result; +use crate::{link, LinkerError, Options, Result}; use rspirv::dr::{Loader, Module}; // https://github.com/colin-kiegel/rust-pretty-assertions/issues/24 @@ -74,7 +72,14 @@ fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result { let mut modules = binaries.iter().cloned().map(load).collect::>(); let mut modules = modules.iter_mut().collect::>(); - link(&mut modules, drop) + link( + &mut modules, + &Options { + dce: false, + compact_ids: true, + }, + drop, + ) } fn without_header_eq(mut result: Module, expected: &str) { diff --git a/rustc_codegen_spirv/src/codegen_cx/declare.rs b/rustc_codegen_spirv/src/codegen_cx/declare.rs index cbc6bf1c5c..48b28fe2e3 100644 --- a/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -3,8 +3,9 @@ use crate::abi::ConvSpirvType; use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use crate::symbols::{parse_attr, SpirvAttribute}; +use rspirv::dr::Operand; use rspirv::spirv::{ - ExecutionMode, ExecutionModel, FunctionControl, LinkageType, StorageClass, Word, + Decoration, ExecutionMode, ExecutionModel, FunctionControl, LinkageType, StorageClass, Word, }; use rustc_attr::InlineAttr; use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods}; @@ -15,6 +16,7 @@ use rustc_middle::ty::{Instance, ParamEnv, Ty, TypeFoldable}; use rustc_span::def_id::DefId; use rustc_target::abi::call::FnAbi; use rustc_target::abi::{Align, LayoutOf}; +use std::collections::HashMap; fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl { let mut control = FunctionControl::NONE; @@ -170,6 +172,7 @@ impl<'tcx> CodegenCx<'tcx> { ), }; let mut emit = self.emit_global(); + let mut decoration_locations = HashMap::new(); // Create OpVariables before OpFunction so they're global instead of local vars. let arguments = entry_func_args .iter() @@ -178,8 +181,30 @@ impl<'tcx> CodegenCx<'tcx> { SpirvType::Pointer { storage_class, .. } => storage_class, other => panic!("Invalid entry arg type {}", other.debug(arg, self)), }; + let has_location = match storage_class { + StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant => { + true + } + _ => false, + }; // Note: this *declares* the variable too. - emit.variable(arg, None, storage_class, None) + let variable = emit.variable(arg, None, storage_class, None); + // Assign locations from left to right, incrementing each storage class + // individually. + // TODO: Is this right for UniformConstant? Do they share locations with + // input/outpus? + if has_location { + let location = decoration_locations + .entry(storage_class) + .or_insert_with(|| 0); + emit.decorate( + variable, + Decoration::Location, + std::iter::once(Operand::LiteralInt32(*location)), + ); + *location += 1; + } + variable }) .collect::>(); let fn_id = emit diff --git a/rustc_codegen_spirv/src/lib.rs b/rustc_codegen_spirv/src/lib.rs index 12c9ecdb49..9d03553a58 100644 --- a/rustc_codegen_spirv/src/lib.rs +++ b/rustc_codegen_spirv/src/lib.rs @@ -281,12 +281,16 @@ impl CodegenBackend for SpirvCodegenBackend { return Ok(()); } + // TODO: Can we merge this sym with the one in symbols.rs? + let legalize = sess.target_features.contains(&Symbol::intern("shader")); + let timer = sess.timer("link_crate"); link::link( sess, &codegen_results, outputs, &codegen_results.crate_name.as_str(), + legalize, ); drop(timer); diff --git a/rustc_codegen_spirv/src/link.rs b/rustc_codegen_spirv/src/link.rs index 3054a40ee2..fcb2a38418 100644 --- a/rustc_codegen_spirv/src/link.rs +++ b/rustc_codegen_spirv/src/link.rs @@ -27,6 +27,7 @@ pub fn link<'a>( codegen_results: &CodegenResults, outputs: &OutputFilenames, crate_name: &str, + legalize: bool, ) { let output_metadata = sess.opts.output_types.contains_key(&OutputType::Metadata); for &crate_type in sess.crate_types().iter() { @@ -59,7 +60,7 @@ pub fn link<'a>( link_rlib(codegen_results, &out_filename); } CrateType::Executable | CrateType::Cdylib | CrateType::Dylib => { - link_exe(sess, crate_type, &out_filename, codegen_results) + link_exe(sess, crate_type, &out_filename, codegen_results, legalize) } other => panic!("CrateType {:?} not supported yet", other), } @@ -98,6 +99,7 @@ fn link_exe( crate_type: CrateType, out_filename: &Path, codegen_results: &CodegenResults, + legalize: bool, ) { let mut objects = Vec::new(); let mut rlibs = Vec::new(); @@ -118,23 +120,29 @@ fn link_exe( do_link(sess, &objects, &rlibs, out_filename); - if env::var("SPIRV_OPT").is_ok() { + let opt = env::var("SPIRV_OPT").is_ok(); + if legalize || opt { let _timer = sess.timer("link_spirv_opt"); - do_spirv_opt(out_filename); + do_spirv_opt(out_filename, legalize, opt); } } -fn do_spirv_opt(filename: &Path) { +fn do_spirv_opt(filename: &Path, legalize: bool, opt: bool) { let tmp = filename.with_extension("opt.spv"); - let status = std::process::Command::new("spirv-opt") - .args(&[ + let mut cmd = std::process::Command::new("spirv-opt"); + if legalize { + cmd.args(&[ "--before-hlsl-legalization", "--inline-entry-points-exhaustive", - "--ssa-rewrite", - "-Os", - "--eliminate-dead-const", - "--strip-debug", - ]) + ]); + if !opt { + cmd.arg("--eliminate-dead-functions"); + } + } + if opt { + cmd.args(&["-Os", "--eliminate-dead-const", "--strip-debug"]); + } + let status = cmd .arg(&filename) .arg("-o") .arg(&tmp) @@ -327,7 +335,11 @@ fn do_link(sess: &Session, objects: &[PathBuf], rlibs: &[PathBuf], out_filename: drop(load_modules_timer); // Do the link... - let link_result = rspirv_linker::link(&mut module_refs, |name| sess.timer(name)); + let options = rspirv_linker::Options { + dce: env::var("NO_DCE").is_err(), + compact_ids: env::var("NO_COMPACT_IDS").is_err(), + }; + let link_result = rspirv_linker::link(&mut module_refs, &options, |name| sess.timer(name)); let save_modules_timer = sess.timer("link_save_modules"); let assembled = match link_result {