diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 234d261453..d7c0964426 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -372,6 +372,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } + fn zombie_ptr_equal(&self, def: Word, inst: &str) { + if !self.builder.has_capability(Capability::VariablePointers) { + self.zombie( + def, + &format!("{} without OpCapability VariablePointers", inst), + ); + } + } + /// If possible, return the appropriate `OpAccessChain` indices for going from /// a pointer to `ty`, to a pointer to `leaf_ty`, with an added `offset`. /// @@ -1433,7 +1442,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvType::Pointer { .. } => match op { IntEQ => { if self.emit().version().unwrap() > (1, 3) { - self.emit().ptr_equal(b, None, lhs.def(self), rhs.def(self)) + self.emit() + .ptr_equal(b, None, lhs.def(self), rhs.def(self)) + .map(|result| { + self.zombie_ptr_equal(result, "OpPtrEqual"); + result + }) } else { let int_ty = self.type_usize(); let lhs = self @@ -1453,6 +1467,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { if self.emit().version().unwrap() > (1, 3) { self.emit() .ptr_not_equal(b, None, lhs.def(self), rhs.def(self)) + .map(|result| { + self.zombie_ptr_equal(result, "OpPtrNotEqual"); + result + }) } else { let int_ty = self.type_usize(); let lhs = self diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index 5f74b76d3a..edad67fabe 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -46,7 +46,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.emit_with_cursor(self.cursor) } - pub fn zombie(&self, word: Word, reason: &'static str) { + pub fn zombie(&self, word: Word, reason: &str) { if let Some(current_span) = self.current_span { self.zombie_with_span(word, current_span, reason); } else { diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 326273f00e..45d15eb220 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -320,10 +320,6 @@ impl BuilderSpirv { } builder.capability(Capability::VulkanMemoryModel); } - builder.capability(Capability::VariablePointers); - if version < (1, 3) { - builder.extension("SPV_KHR_variable_pointers"); - } } // The linker will always be ran on this module diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 4b16c723a8..f7750c23fe 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -19,7 +19,8 @@ pub fn inline(module: &mut Module) { .iter() .map(|f| (f.def_id().unwrap(), f.clone())) .collect(); - let disallowed_argument_types = compute_disallowed_argument_types(module); + let (disallowed_argument_types, disallowed_return_types) = + compute_disallowed_argument_and_return_types(module); let void = module .types_global_values .iter() @@ -30,7 +31,7 @@ pub fn inline(module: &mut Module) { // inlines in functions that will get inlined) let mut dropped_ids = FxHashSet::default(); module.functions.retain(|f| { - if should_inline(&disallowed_argument_types, f) { + if should_inline(&disallowed_argument_types, &disallowed_return_types, f) { // TODO: We should insert all defined IDs in this function. dropped_ids.insert(f.def_id().unwrap()); false @@ -51,6 +52,7 @@ pub fn inline(module: &mut Module) { void, functions: &functions, disallowed_argument_types: &disallowed_argument_types, + disallowed_return_types: &disallowed_return_types, }; for function in &mut module.functions { inliner.inline_fn(function); @@ -58,17 +60,19 @@ pub fn inline(module: &mut Module) { } } -fn compute_disallowed_argument_types(module: &Module) -> FxHashSet { +fn compute_disallowed_argument_and_return_types( + module: &Module, +) -> (FxHashSet, FxHashSet) { let allowed_argument_storage_classes = &[ StorageClass::UniformConstant, StorageClass::Function, StorageClass::Private, StorageClass::Workgroup, StorageClass::AtomicCounter, - // TODO: StorageBuffer is allowed if VariablePointers is enabled ]; let mut disallowed_argument_types = FxHashSet::default(); let mut disallowed_pointees = FxHashSet::default(); + let mut disallowed_return_types = FxHashSet::default(); for inst in &module.types_global_values { match inst.class.opcode { Op::TypePointer => { @@ -81,24 +85,19 @@ fn compute_disallowed_argument_types(module: &Module) -> FxHashSet { disallowed_argument_types.insert(inst.result_id.unwrap()); } disallowed_pointees.insert(inst.result_id.unwrap()); + disallowed_return_types.insert(inst.result_id.unwrap()); } Op::TypeStruct => { - if inst - .operands - .iter() - .map(|op| op.id_ref_any().unwrap()) - .any(|id| disallowed_argument_types.contains(&id)) - { + let fields = || inst.operands.iter().map(|op| op.id_ref_any().unwrap()); + if fields().any(|id| disallowed_argument_types.contains(&id)) { disallowed_argument_types.insert(inst.result_id.unwrap()); } - if inst - .operands - .iter() - .map(|op| op.id_ref_any().unwrap()) - .any(|id| disallowed_pointees.contains(&id)) - { + if fields().any(|id| disallowed_pointees.contains(&id)) { disallowed_pointees.insert(inst.result_id.unwrap()); } + if fields().any(|id| disallowed_return_types.contains(&id)) { + disallowed_return_types.insert(inst.result_id.unwrap()); + } } Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector => { let id = inst.operands[0].id_ref_any().unwrap(); @@ -112,10 +111,14 @@ fn compute_disallowed_argument_types(module: &Module) -> FxHashSet { _ => {} } } - disallowed_argument_types + (disallowed_argument_types, disallowed_return_types) } -fn should_inline(disallowed_argument_types: &FxHashSet, function: &Function) -> bool { +fn should_inline( + disallowed_argument_types: &FxHashSet, + disallowed_return_types: &FxHashSet, + function: &Function, +) -> bool { let def = function.def.as_ref().unwrap(); let control = def.operands[0].unwrap_function_control(); control.contains(FunctionControl::INLINE) @@ -123,6 +126,7 @@ fn should_inline(disallowed_argument_types: &FxHashSet, function: &Functio .parameters .iter() .any(|inst| disallowed_argument_types.contains(inst.result_type.as_ref().unwrap())) + || disallowed_return_types.contains(&function.def.as_ref().unwrap().result_type.unwrap()) } // Steps: @@ -137,6 +141,7 @@ struct Inliner<'m, 'map> { void: Word, functions: &'map FunctionMap, disallowed_argument_types: &'map FxHashSet, + disallowed_return_types: &'map FxHashSet, // rewrite_rules: FxHashMap, } @@ -198,7 +203,13 @@ impl Inliner<'_, '_> { .unwrap(), ) }) - .find(|(_, _, f)| should_inline(self.disallowed_argument_types, f)); + .find(|(_, _, f)| { + should_inline( + self.disallowed_argument_types, + self.disallowed_return_types, + f, + ) + }); let (call_index, call_inst, callee) = match call { None => return false, Some(call) => call, diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index c8f89ba022..d12648d89c 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -27,7 +27,6 @@ fn custom_entry_point() { pub fn main() { } "#, r#"OpCapability Shader -OpCapability VariablePointers OpMemoryModel Logical Simple OpEntryPoint Fragment %1 "hello_world" OpExecutionMode %1 OriginUpperLeft @@ -171,7 +170,6 @@ fn asm_op_decorate() { add_decorate(); }"#, r#"OpCapability Shader -OpCapability VariablePointers OpCapability RuntimeDescriptorArray OpExtension "SPV_EXT_descriptor_indexing" OpMemoryModel Logical Simple diff --git a/tests/ui/lang/consts/nested-ref.rs b/tests/ui/lang/consts/nested-ref.rs index 58a8799f4c..096550abcf 100644 --- a/tests/ui/lang/consts/nested-ref.rs +++ b/tests/ui/lang/consts/nested-ref.rs @@ -26,6 +26,7 @@ pub fn main( bool_out: &mut bool, vec_out: &mut Vec2, ) { + unsafe { asm!("OpCapability VariablePointers") }; *scalar_out = deep_load(&&123); *bool_out = vec_in == &Vec2::ZERO; *vec_out = deep_transpose(&ROT90) * *vec_in;