diff --git a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs index 9ce66ab9a3..fcdaec0793 100644 --- a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs +++ b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs @@ -55,9 +55,12 @@ impl ExtInst { } if !bx .builder - .has_extension("SPV_INTEL_shader_integer_functions2") + .has_extension(bx.sym.spv_intel_shader_integer_functions2) { - bx.zombie(to_zombie, "extension IntegerFunctions2INTEL is required"); + bx.zombie( + to_zombie, + "extension SPV_INTEL_shader_integer_functions2 is required", + ); } } } diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 5fa2ef9f46..aed107a06a 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -1,13 +1,15 @@ use crate::builder; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; +use crate::symbols::Symbols; use crate::target::SpirvTarget; use crate::target_feature::TargetFeature; use rspirv::dr::{Block, Builder, Module, Operand}; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, StorageClass, Word}; use rspirv::{binary::Assemble, binary::Disassemble}; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_middle::bug; +use rustc_span::symbol::Symbol; use rustc_span::{Span, DUMMY_SP}; use std::cell::{RefCell, RefMut}; use std::rc::Rc; @@ -319,40 +321,84 @@ pub struct BuilderSpirv { const_to_id: RefCell, WithConstLegality>>, id_to_const: RefCell>>, string_cache: RefCell>, + + enabled_capabilities: FxHashSet, + enabled_extensions: FxHashSet, } impl BuilderSpirv { - pub fn new(target: &SpirvTarget, features: &[TargetFeature]) -> Self { + pub fn new( + sym: &Symbols, + target: &SpirvTarget, + features: &[TargetFeature], + bindless: bool, + ) -> Self { let version = target.spirv_version(); let memory_model = target.memory_model(); let mut builder = Builder::new(); builder.set_version(version.0, version.1); + let mut enabled_capabilities = FxHashSet::default(); + let mut enabled_extensions = FxHashSet::default(); + + fn add_cap( + builder: &mut Builder, + enabled_capabilities: &mut FxHashSet, + cap: Capability, + ) { + // This should be the only callsite of Builder::capability (aside from tests), to make + // sure the hashset stays in sync. + builder.capability(cap); + enabled_capabilities.insert(cap); + } + fn add_ext(builder: &mut Builder, enabled_extensions: &mut FxHashSet, ext: Symbol) { + // This should be the only callsite of Builder::extension (aside from tests), to make + // sure the hashset stays in sync. + builder.extension(&*ext.as_str()); + enabled_extensions.insert(ext); + } + for feature in features { - match feature { - TargetFeature::Capability(cap) => builder.capability(*cap), - TargetFeature::Extension(ext) => builder.extension(&*ext.as_str()), + match *feature { + TargetFeature::Capability(cap) => { + add_cap(&mut builder, &mut enabled_capabilities, cap); + } + TargetFeature::Extension(ext) => { + add_ext(&mut builder, &mut enabled_extensions, ext); + } } } if target.is_kernel() { - builder.capability(Capability::Kernel); + add_cap(&mut builder, &mut enabled_capabilities, Capability::Kernel); } else { - builder.capability(Capability::Shader); + add_cap(&mut builder, &mut enabled_capabilities, Capability::Shader); if memory_model == MemoryModel::Vulkan { if version < (1, 5) { - builder.extension("SPV_KHR_vulkan_memory_model"); + add_ext( + &mut builder, + &mut enabled_extensions, + sym.spv_khr_vulkan_memory_model, + ); } - builder.capability(Capability::VulkanMemoryModel); + add_cap( + &mut builder, + &mut enabled_capabilities, + Capability::VulkanMemoryModel, + ); } } // The linker will always be ran on this module - builder.capability(Capability::Linkage); + add_cap(&mut builder, &mut enabled_capabilities, Capability::Linkage); let addressing_model = if target.is_kernel() { - builder.capability(Capability::Addresses); + add_cap( + &mut builder, + &mut enabled_capabilities, + Capability::Addresses, + ); AddressingModel::Physical32 } else { AddressingModel::Logical @@ -360,11 +406,26 @@ impl BuilderSpirv { builder.memory_model(addressing_model, memory_model); + if bindless { + add_ext( + &mut builder, + &mut enabled_extensions, + sym.spv_ext_descriptor_indexing, + ); + add_cap( + &mut builder, + &mut enabled_capabilities, + Capability::RuntimeDescriptorArray, + ); + } + Self { builder: RefCell::new(builder), const_to_id: Default::default(), id_to_const: Default::default(), string_cache: Default::default(), + enabled_capabilities, + enabled_extensions, } } @@ -410,27 +471,11 @@ impl BuilderSpirv { } pub fn has_capability(&self, capability: Capability) -> bool { - self.builder - .borrow() - .module_ref() - .capabilities - .iter() - .any(|inst| { - inst.class.opcode == Op::Capability - && inst.operands[0].unwrap_capability() == capability - }) + self.enabled_capabilities.contains(&capability) } - pub fn has_extension(&self, extension: &str) -> bool { - self.builder - .borrow() - .module_ref() - .extensions - .iter() - .any(|inst| { - inst.class.opcode == Op::Extension - && inst.operands[0].unwrap_literal_string() == extension - }) + pub fn has_extension(&self, extension: Symbol) -> bool { + self.enabled_extensions.contains(&extension) } pub fn select_function_by_id(&self, id: Word) -> BuilderCursor { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 9c36ebb67b..df7c207b28 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -312,23 +312,17 @@ impl<'tcx> CodegenCx<'tcx> { bx.call(entry_func, &call_args, None); bx.ret_void(); - if self.bindless() { - self.emit_global().extension("SPV_EXT_descriptor_indexing"); - self.emit_global() - .capability(Capability::RuntimeDescriptorArray); + if self.bindless() && self.target.spirv_version() > (1, 3) { + let sets = self.bindless_descriptor_sets.borrow().unwrap(); - if self.target.spirv_version() > (1, 3) { - let sets = self.bindless_descriptor_sets.borrow().unwrap(); + op_entry_point_interface_operands.push(sets.buffers); - op_entry_point_interface_operands.push(sets.buffers); - - //op_entry_point_interface_operands - // .push(sets.sampled_image_1d); - // op_entry_point_interface_operands - // .push(sets.sampled_image_2d); - //op_entry_point_interface_operands - //.push(sets.sampled_image_3d); - } + //op_entry_point_interface_operands + // .push(sets.sampled_image_1d); + // op_entry_point_interface_operands + // .push(sets.sampled_image_2d); + //op_entry_point_interface_operands + //.push(sets.sampled_image_3d); } let stub_fn_id = stub_fn.def_cx(self); diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index 56d513d347..9b829be474 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -132,7 +132,7 @@ impl<'tcx> CodegenCx<'tcx> { let result = Self { tcx, codegen_unit, - builder: BuilderSpirv::new(&target, &features), + builder: BuilderSpirv::new(&sym, &target, &features, bindless), instances: Default::default(), function_parameter_values: Default::default(), type_cache: Default::default(), diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index 5f6d804012..1837f7de87 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -21,6 +21,9 @@ pub struct Symbols { pub libm: Symbol, pub num_traits: Symbol, pub entry_point_name: Symbol, + pub spv_intel_shader_integer_functions2: Symbol, + pub spv_khr_vulkan_memory_model: Symbol, + pub spv_ext_descriptor_indexing: Symbol, descriptor_set: Symbol, binding: Symbol, input_attachment_index: Symbol, @@ -366,11 +369,16 @@ impl Symbols { Self { fmt_decimal: Symbol::intern("fmt_decimal"), - entry_point_name: Symbol::intern("entry_point_name"), spirv: Symbol::intern("spirv"), spirv_std: Symbol::intern("spirv_std"), libm: Symbol::intern("libm"), num_traits: Symbol::intern("num_traits"), + entry_point_name: Symbol::intern("entry_point_name"), + spv_intel_shader_integer_functions2: Symbol::intern( + "SPV_INTEL_shader_integer_functions2", + ), + spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"), + spv_ext_descriptor_indexing: Symbol::intern("SPV_EXT_descriptor_indexing"), descriptor_set: Symbol::intern("descriptor_set"), binding: Symbol::intern("binding"), input_attachment_index: Symbol::intern("input_attachment_index"),