From 67746012f5ccc09cbe4a915128a5114e12701072 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Wed, 10 Feb 2021 16:34:57 +0200 Subject: [PATCH] Infer storage classes using the specializer, replacing special pointer types. --- crates/rustc_codegen_spirv/src/abi.rs | 89 +++++++------------ .../src/builder/builder_methods.rs | 56 ++++-------- .../src/builder/intrinsics.rs | 8 +- crates/rustc_codegen_spirv/src/builder/mod.rs | 16 +--- .../src/builder/spirv_asm.rs | 55 ++++++++---- .../rustc_codegen_spirv/src/builder_spirv.rs | 5 +- .../src/codegen_cx/constant.rs | 25 +----- .../src/codegen_cx/declare.rs | 11 +-- .../src/codegen_cx/entry.rs | 49 +++++----- .../rustc_codegen_spirv/src/codegen_cx/mod.rs | 9 +- .../src/codegen_cx/type_.rs | 18 +--- crates/rustc_codegen_spirv/src/linker/mod.rs | 54 ++++++++++- crates/rustc_codegen_spirv/src/spirv_type.rs | 40 ++++----- crates/spirv-builder/src/test/basic.rs | 14 ++- 14 files changed, 208 insertions(+), 241 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index 1f10d1e5c4..65b5867539 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -27,19 +27,12 @@ use std::fmt; /// tracking. #[derive(Default)] pub struct RecursivePointeeCache<'tcx> { - map: RefCell, StorageClass), PointeeDefState>>, + map: RefCell, PointeeDefState>>, } impl<'tcx> RecursivePointeeCache<'tcx> { - fn begin( - &self, - cx: &CodegenCx<'tcx>, - span: Span, - pointee: PointeeTy<'tcx>, - storage_class: StorageClass, - ) -> Option { - // Warning: storage_class must match the one called with end() - match self.map.borrow_mut().entry((pointee, storage_class)) { + fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option { + match self.map.borrow_mut().entry(pointee) { // State: This is the first time we've seen this type. Record that we're beginning to translate this type, // and start doing the translation. Entry::Vacant(entry) => { @@ -52,7 +45,11 @@ impl<'tcx> RecursivePointeeCache<'tcx> { // emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm) PointeeDefState::Defining => { let new_id = cx.emit_global().id(); - cx.emit_global().type_forward_pointer(new_id, storage_class); + // NOTE(eddyb) we emit `StorageClass::Generic` here, but later + // the linker will specialize the entire SPIR-V module to use + // storage classes inferred from `OpVariable`s. + cx.emit_global() + .type_forward_pointer(new_id, StorageClass::Generic); entry.insert(PointeeDefState::DefiningWithForward(new_id)); if !cx.builder.has_capability(Capability::Addresses) && !cx @@ -81,11 +78,9 @@ impl<'tcx> RecursivePointeeCache<'tcx> { cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>, - storage_class: StorageClass, pointee_spv: Word, ) -> Word { - // Warning: storage_class must match the one called with begin() - match self.map.borrow_mut().entry((pointee, storage_class)) { + match self.map.borrow_mut().entry(pointee) { // We should have hit begin() on this type already, which always inserts an entry. Entry::Vacant(_) => bug!("RecursivePointeeCache::end should always have entry"), Entry::Occupied(mut entry) => match *entry.get() { @@ -93,7 +88,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> { // OpTypeForwardPointer has been emitted. This is the most common case. PointeeDefState::Defining => { let id = SpirvType::Pointer { - storage_class, pointee: pointee_spv, } .def(span, cx); @@ -105,7 +99,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> { PointeeDefState::DefiningWithForward(id) => { entry.insert(PointeeDefState::Defined(id)); SpirvType::Pointer { - storage_class, pointee: pointee_spv, } .def_with_id(cx, span, id) @@ -261,11 +254,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> { PassMode::Cast(cast_target) => cast_target.spirv_type(span, cx), PassMode::Indirect { .. } => { let pointee = self.ret.layout.spirv_type(span, cx); - let pointer = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee, - } - .def(span, cx); + let pointer = SpirvType::Pointer { pointee }.def(span, cx); // Important: the return pointer comes *first*, not last. argument_types.push(pointer); SpirvType::Void.def(span, cx) @@ -304,11 +293,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> { extra_attrs: None, .. } => { let pointee = arg.layout.spirv_type(span, cx); - SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee, - } - .def(span, cx) + SpirvType::Pointer { pointee }.def(span, cx) } }; argument_types.push(arg_type); @@ -465,26 +450,20 @@ fn trans_scalar<'tcx>( Primitive::F32 => SpirvType::Float(32).def(span, cx), Primitive::F64 => SpirvType::Float(64).def(span, cx), Primitive::Pointer => { - let (storage_class, pointee_ty) = dig_scalar_pointee(cx, ty, index); - // Default to function storage class. - let storage_class = storage_class.unwrap_or(StorageClass::Function); + let pointee_ty = dig_scalar_pointee(cx, ty, index); // Pointers can be recursive. So, record what we're currently translating, and if we're already translating // the same type, emit an OpTypeForwardPointer and use that ID. - if let Some(predefined_result) = - cx.type_cache - .recursive_pointee_cache - .begin(cx, span, pointee_ty, storage_class) + if let Some(predefined_result) = cx + .type_cache + .recursive_pointee_cache + .begin(cx, span, pointee_ty) { predefined_result } else { let pointee = pointee_ty.spirv_type(span, cx); - cx.type_cache.recursive_pointee_cache.end( - cx, - span, - pointee_ty, - storage_class, - pointee, - ) + cx.type_cache + .recursive_pointee_cache + .end(cx, span, pointee_ty, pointee) } } } @@ -504,12 +483,12 @@ fn dig_scalar_pointee<'tcx>( cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>, index: Option, -) -> (Option, PointeeTy<'tcx>) { +) -> PointeeTy<'tcx> { match *ty.ty.kind() { TyKind::Ref(_, elem_ty, _) | TyKind::RawPtr(TypeAndMut { ty: elem_ty, .. }) => { let elem = cx.layout_of(elem_ty); match index { - None => (None, PointeeTy::Ty(elem)), + None => PointeeTy::Ty(elem), Some(index) => { if elem.is_unsized() { dig_scalar_pointee(cx, ty.field(cx, index), None) @@ -518,12 +497,12 @@ fn dig_scalar_pointee<'tcx>( // of ScalarPair could be deduced, but it's actually e.g. a sized pointer followed by some other // completely unrelated type, not a wide pointer. So, translate this as a single scalar, one // component of that ScalarPair. - (None, PointeeTy::Ty(elem)) + PointeeTy::Ty(elem) } } } } - TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)), + TyKind::FnPtr(sig) if index.is_none() => PointeeTy::Fn(sig), TyKind::Adt(def, _) if def.is_box() => { let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty())); dig_scalar_pointee(cx, ptr_ty, index) @@ -542,11 +521,8 @@ fn dig_scalar_pointee_adt<'tcx>( cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>, index: Option, -) -> (Option, PointeeTy<'tcx>) { - // Storage classes can only be applied on structs containing a single pointer field (because we said so), so we only - // need to handle the attribute here. - let storage_class = get_storage_class(cx, ty); - let result = match &ty.variants { +) -> PointeeTy<'tcx> { + match &ty.variants { // If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the // discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying // type, only available in the dataful variant. @@ -603,21 +579,16 @@ fn dig_scalar_pointee_adt<'tcx>( }, } } - }; - match (storage_class, result) { - (storage_class, (None, result)) => (storage_class, result), - (None, (storage_class, result)) => (storage_class, result), - (Some(one), (Some(two), _)) => cx.tcx.sess.fatal(&format!( - "Double-applied storage class ({:?} and {:?}) on type {}", - one, two, ty.ty - )), } } -/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the scalar translation code, because this is only +/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the entry interface variables code, because this is only /// used for spooky builtin stuff, and we pinky promise to never have more than one pointer field in one of these. // TODO: Enforce this is only used in spirv-std. -fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option { +pub(crate) fn get_storage_class<'tcx>( + cx: &CodegenCx<'tcx>, + ty: TyAndLayout<'tcx>, +) -> Option { if let TyKind::Adt(adt, _substs) = ty.ty.kind() { for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) { if let SpirvAttribute::StorageClass(storage_class) = attr { diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 9db323ca16..ecb68be8c2 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -727,11 +727,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn alloca(&mut self, ty: Self::Type, _align: Align) -> Self::Value { - let ptr_ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(self.span(), self); + let ptr_ty = SpirvType::Pointer { pointee: ty }.def(self.span(), self); // "All OpVariable instructions in a function must be the first instructions in the first block." let mut builder = self.emit(); builder.select_block(Some(0)).unwrap(); @@ -779,10 +775,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return value; } let ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "load called on variable that wasn't a pointer: {:?}", ty @@ -803,10 +796,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn atomic_load(&mut self, ptr: Self::Value, order: AtomicOrdering, _size: Size) -> Self::Value { let ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_load called on variable that wasn't a pointer: {:?}", ty @@ -906,10 +896,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value { let ptr_elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "store called on variable that wasn't a pointer: {:?}", ty @@ -946,10 +933,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { _size: Size, ) { let ptr_elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_store called on variable that wasn't a pointer: {:?}", ty @@ -979,15 +963,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn struct_gep(&mut self, ptr: Self::Value, idx: u64) -> Self::Value { - let (storage_class, result_pointee_type) = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class, - pointee, - } => match self.lookup_type(pointee) { - SpirvType::Adt { field_types, .. } => (storage_class, field_types[idx as usize]), + let result_pointee_type = match self.lookup_type(ptr.ty) { + SpirvType::Pointer { pointee } => match self.lookup_type(pointee) { + SpirvType::Adt { field_types, .. } => field_types[idx as usize], SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } - | SpirvType::Vector { element, .. } => (storage_class, element), + | SpirvType::Vector { element, .. } => element, other => self.fatal(&format!( "struct_gep not on struct, array, or vector type: {:?}, index {}", other, idx @@ -999,7 +980,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; let result_type = SpirvType::Pointer { - storage_class, pointee: result_pointee_type, } .def(self.span(), self); @@ -1225,14 +1205,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { let val_pointee = match self.lookup_type(val.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.fatal(&format!( "pointercast called on non-pointer source type: {:?}", other )), }; let dest_pointee = match self.lookup_type(dest_ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.fatal(&format!( "pointercast called on non-pointer dest type: {:?}", other @@ -1531,7 +1511,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return; } let src_pointee = match self.lookup_type(src.ty) { - SpirvType::Pointer { pointee, .. } => Some(pointee), + SpirvType::Pointer { pointee } => Some(pointee), _ => None, }; let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self)); @@ -1592,7 +1572,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )); } let elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, _ => self.fatal(&format!( "memset called on non-pointer type: {}", self.debug_type(ptr.ty) @@ -1780,10 +1760,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { _weak: bool, ) -> Self::Value { let dst_pointee_ty = match self.lookup_type(dst.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_cmpxchg called on variable that wasn't a pointer: {:?}", ty @@ -1820,10 +1797,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { order: AtomicOrdering, ) -> Self::Value { let dst_pointee_ty = match self.lookup_type(dst.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_rmw called on variable that wasn't a pointer: {:?}", ty diff --git a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs index 0ca218fa30..2103723646 100644 --- a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs +++ b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs @@ -3,7 +3,7 @@ use crate::abi::ConvSpirvType; use crate::builder_spirv::{SpirvValue, SpirvValueExt}; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use rspirv::spirv::{CLOp, GLOp, StorageClass}; +use rspirv::spirv::{CLOp, GLOp}; use rustc_codegen_ssa::mir::operand::OperandRef; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{BuilderMethods, IntrinsicCallMethods}; @@ -103,11 +103,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { let mut ptr = args[0].immediate(); if let PassMode::Cast(ty) = fn_abi.ret.mode { let pointee = ty.spirv_type(self.span(), self); - let pointer = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee, - } - .def(self.span(), self); + let pointer = SpirvType::Pointer { pointee }.def(self.span(), self); ptr = self.pointercast(ptr, pointer); } let load = self.volatile_load(ptr); diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index 2895e8a241..5f74b76d3a 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -12,7 +12,7 @@ use crate::abi::ConvSpirvType; use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt}; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use rspirv::spirv::{StorageClass, Word}; +use rspirv::spirv::Word; use rustc_codegen_ssa::mir::operand::OperandValue; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{ @@ -104,11 +104,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // "An OpAccessChain instruction is the equivalent of an LLVM getelementptr instruction where the first index element is zero." // https://github.com/gpuweb/gpuweb/issues/33 let mut result_indices = Vec::with_capacity(indices.len() - 1); - let (storage_class, mut result_pointee_type) = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class, - pointee, - } => (storage_class, pointee), + let mut result_pointee_type = match self.lookup_type(ptr.ty) { + SpirvType::Pointer { pointee } => pointee, other_type => self.fatal(&format!( "GEP first deref not implemented for type {:?}", other_type @@ -125,7 +122,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }; } let result_type = SpirvType::Pointer { - storage_class, pointee: result_pointee_type, } .def(self.span(), self); @@ -330,11 +326,7 @@ impl<'a, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'tcx> { self.fatal("unsized `ArgAbi` must be handled through `store_fn_arg`"); } else if let PassMode::Cast(cast) = arg_abi.mode { let cast_ty = cast.spirv_type(self.span(), self); - let cast_ptr_ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: cast_ty, - } - .def(self.span(), self); + let cast_ptr_ty = SpirvType::Pointer { pointee: cast_ty }.def(self.span(), self); let cast_dst = self.pointercast(dst.llval, cast_ptr_ty); self.store(val, cast_dst, arg_abi.layout.align.abi); } else { diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 35efbf261f..13c0c4caf1 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -263,11 +263,25 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { element: inst.operands[0].unwrap_id_ref(), } .def(self.span(), self), - Op::TypePointer => SpirvType::Pointer { - storage_class: inst.operands[0].unwrap_storage_class(), - pointee: inst.operands[1].unwrap_id_ref(), + Op::TypePointer => { + let storage_class = inst.operands[0].unwrap_storage_class(); + if storage_class != StorageClass::Generic { + self.struct_err("TypePointer in asm! requires `Generic` storage class") + .note(&format!( + "`{:?}` storage class was specified", + storage_class + )) + .help(&format!( + "the storage class will be inferred automatically (e.g. to `{:?}`)", + storage_class + )) + .emit(); + } + SpirvType::Pointer { + pointee: inst.operands[1].unwrap_id_ref(), + } + .def(self.span(), self) } - .def(self.span(), self), Op::TypeImage => SpirvType::Image { sampled_type: inst.operands[0].unwrap_id_ref(), dim: inst.operands[1].unwrap_dim(), @@ -511,24 +525,26 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { use crate::spirv_type_constraints::{instruction_signatures, InstSig, TyListPat, TyPat}; #[derive(Debug)] - struct Mismatch; + struct Unapplicable; /// Recursively match `ty` against `pat`, returning one of: /// * `Ok(None)`: `pat` matched but contained no type variables /// * `Ok(Some(var))`: `pat` matched and `var` is the type variable /// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now - fn apply_ty_pat( + fn match_ty_pat( cx: &CodegenCx<'_>, pat: &TyPat<'_>, ty: Word, - ) -> Result, Mismatch> { + ) -> Result, Unapplicable> { match pat { TyPat::Any => Ok(None), &TyPat::T => Ok(Some(ty)), TyPat::Either(a, b) => { - apply_ty_pat(cx, a, ty).or_else(|Mismatch| apply_ty_pat(cx, b, ty)) + match_ty_pat(cx, a, ty).or_else(|Unapplicable| match_ty_pat(cx, b, ty)) } _ => match (pat, cx.lookup_type(ty)) { + (TyPat::Any, _) | (&TyPat::T, _) | (TyPat::Either(..), _) => unreachable!(), + (TyPat::Void, SpirvType::Void) => Ok(None), (TyPat::Pointer(_, pat), SpirvType::Pointer { pointee: ty, .. }) | (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. }) @@ -546,9 +562,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { }, ) | (TyPat::SampledImage(pat), SpirvType::SampledImage { image_type: ty }) => { - apply_ty_pat(cx, pat, ty) + match_ty_pat(cx, pat, ty) } - _ => Err(Mismatch), + _ => Err(Unapplicable), }, } } @@ -567,17 +583,19 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any()); while let TyListPat::Cons { first: pat, suffix } = *sig.input_types { - let apply_result = match id_to_type_map.get(&ids.next()?) { - Some(&ty) => apply_ty_pat(self, pat, ty), + sig.input_types = suffix; + + let match_result = match id_to_type_map.get(&ids.next()?) { + Some(&ty) => match_ty_pat(self, pat, ty), // Non-value ID operand (or value operand of unknown type), // only `TyPat::Any` is valid. None => match pat { TyPat::Any => Ok(None), - _ => Err(Mismatch), + _ => Err(Unapplicable), }, }; - match apply_result { + match match_result { Ok(Some(var)) => match combined_var { Some(combined_var) => { // FIXME(eddyb) this could use some error reporting @@ -591,11 +609,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { None => combined_var = Some(var), }, Ok(None) => {} - Err(Mismatch) => return None, + Err(Unapplicable) => return None, } - sig.input_types = suffix; } match sig.input_types { + TyListPat::Cons { .. } => unreachable!(), + TyListPat::Any => {} TyListPat::Nil => { if ids.next().is_some() { @@ -742,7 +761,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { Some(match kind { TypeofKind::Plain => ty, TypeofKind::Dereference => match self.lookup_type(ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => { self.tcx.sess.span_err( span, @@ -764,7 +783,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { self.check_reg(span, reg); match place { Some(place) => match self.lookup_type(place.llval.ty) { - SpirvType::Pointer { pointee, .. } => Some(pointee), + SpirvType::Pointer { pointee } => Some(pointee), other => { self.tcx.sess.span_err( span, diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 12c705f050..caf3de1cdb 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -44,10 +44,7 @@ impl SpirvValue { global_var: _, } => { let ty = match cx.lookup_type(self.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => bug!("load called on variable that wasn't a pointer: {:?}", ty), }; Some(initializer.with_type(ty)) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index c9bbbbdf76..0245828133 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -271,7 +271,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { let (base_addr, _base_addr_space) = match self.tcx.global_alloc(ptr.alloc_id) { GlobalAlloc::Memory(alloc) => { let pointee = match self.lookup_type(ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.tcx.sess.fatal(&format!( "GlobalAlloc::Memory type not implemented: {}", other.debug(ty, self) @@ -306,28 +306,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { .fatal("Non-pointer-typed scalar_to_backend Scalar::Ptr not supported"); // unsafe { llvm::LLVMConstPtrToInt(llval, llty) } } else { - match (self.lookup_type(value.ty), self.lookup_type(ty)) { - ( - SpirvType::Pointer { - storage_class: a_space, - pointee: a, - }, - SpirvType::Pointer { - storage_class: b_space, - pointee: b, - }, - ) => { - if a_space != b_space { - // TODO: Emit the correct type that is passed into this function. - self.zombie_no_span( - value.def_cx(self), - "invalid pointer space in constant", - ); - } - assert_ty_eq!(self, a, b); - } - _ => assert_ty_eq!(self, value.ty, ty), - } + assert_ty_eq!(self, value.ty, ty); value } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 6d0156fe29..77a548bc11 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -184,14 +184,11 @@ impl<'tcx> CodegenCx<'tcx> { } fn declare_global(&self, span: Span, ty: Word) -> SpirvValue { - let ptr_ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(span, self); + let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self); + // FIXME(eddyb) figure out what the correct storage class is. let result = self .emit_global() - .variable(ptr_ty, None, StorageClass::Function, None) + .variable(ptr_ty, None, StorageClass::Private, None) .with_type(ptr_ty); // TODO: These should be StorageClass::Private, so just zombie for now. self.zombie_with_span(result.def_cx(self), span, "Globals are not supported yet"); @@ -264,7 +261,7 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> { Err(_) => return, }; let value_ty = match self.lookup_type(g.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.tcx.sess.fatal(&format!( "global had non-pointer type {}", other.debug(g.ty, self) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 08cc9bdd0d..560da694fa 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -1,10 +1,12 @@ use super::CodegenCx; +use crate::abi::ConvSpirvType; use crate::builder_spirv::SpirvValue; use crate::spirv_type::SpirvType; use crate::symbols::{parse_attrs, Entry, SpirvAttribute}; use rspirv::dr::Operand; use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word}; -use rustc_hir::{Param, PatKind}; +use rustc_hir as hir; +use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Instance, Ty}; use rustc_span::Span; use rustc_target::abi::call::{FnAbi, PassMode}; @@ -18,7 +20,7 @@ impl<'tcx> CodegenCx<'tcx> { pub fn entry_stub( &self, instance: &Instance<'_>, - fn_abi: &FnAbi<'_, Ty<'_>>, + fn_abi: &FnAbi<'tcx, Ty<'tcx>>, entry_func: SpirvValue, name: String, entry: Entry, @@ -60,6 +62,7 @@ impl<'tcx> CodegenCx<'tcx> { self.shader_entry_stub( self.tcx.def_span(instance.def_id()), entry_func, + fn_abi, body.params, name, execution_model, @@ -78,7 +81,8 @@ impl<'tcx> CodegenCx<'tcx> { &self, span: Span, entry_func: SpirvValue, - hir_params: &[Param<'tcx>], + entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + hir_params: &[hir::Param<'tcx>], name: String, execution_model: ExecutionModel, ) -> Word { @@ -88,11 +92,11 @@ impl<'tcx> CodegenCx<'tcx> { arguments: vec![], } .def(span, self); - let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { + let entry_func_return_type = match self.lookup_type(entry_func.ty) { SpirvType::Function { return_type, - arguments, - } => (return_type, arguments), + arguments: _, + } => return_type, other => self.tcx.sess.fatal(&format!( "Invalid entry_stub type: {}", other.debug(entry_func.ty, self) @@ -100,11 +104,12 @@ impl<'tcx> CodegenCx<'tcx> { }; let mut decoration_locations = HashMap::new(); // Create OpVariables before OpFunction so they're global instead of local vars. - let arguments = entry_func_args + let arguments = entry_fn_abi + .args .iter() .zip(hir_params) - .map(|(&arg, hir_param)| { - self.declare_parameter(arg, hir_param, &mut decoration_locations) + .map(|(entry_fn_arg, hir_param)| { + self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations) }) .collect::>(); let mut emit = self.emit_global(); @@ -113,7 +118,7 @@ impl<'tcx> CodegenCx<'tcx> { .unwrap(); emit.begin_block(None).unwrap(); emit.function_call( - entry_func_return, + entry_func_return_type, None, entry_func.def_cx(self), arguments.iter().map(|&(a, _)| a), @@ -139,24 +144,26 @@ impl<'tcx> CodegenCx<'tcx> { fn declare_parameter( &self, - arg: Word, - hir_param: &Param<'tcx>, + layout: TyAndLayout<'tcx>, + hir_param: &hir::Param<'tcx>, decoration_locations: &mut HashMap, ) -> (Word, StorageClass) { - let storage_class = match self.lookup_type(arg) { - SpirvType::Pointer { storage_class, .. } => storage_class, - other => self.tcx.sess.fatal(&format!( - "Invalid entry arg type {}", - other.debug(arg, self) - )), - }; + let storage_class = crate::abi::get_storage_class(self, layout).unwrap_or_else(|| { + self.tcx.sess.span_fatal( + hir_param.span, + &format!("invalid entry param type `{}`", layout.ty), + ); + }); let mut has_location = matches!( storage_class, StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant ); // Note: this *declares* the variable too. - let variable = self.emit_global().variable(arg, None, storage_class, None); - if let PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { + let spirv_type = layout.spirv_type(hir_param.span, self); + let variable = self + .emit_global() + .variable(spirv_type, None, storage_class, None); + if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { self.emit_global().name(variable, ident.to_string()); } for attr in parse_attrs(self, hir_param.attrs) { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index 3158d4c433..408e011e82 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -217,18 +217,15 @@ impl<'tcx> CodegenCx<'tcx> { /// See note on `SpirvValueKind::ConstantPointer` pub fn make_constant_pointer(&self, span: Span, value: SpirvValue) -> SpirvValue { - let ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: value.ty, - } - .def(span, self); + let ty = SpirvType::Pointer { pointee: value.ty }.def(span, self); let initializer = value.def_cx(self); // Create these up front instead of on demand in SpirvValue::def because // SpirvValue::def can't use cx.emit() + // FIXME(eddyb) figure out what the correct storage class is. let global_var = self.emit_global() - .variable(ty, None, StorageClass::Function, Some(initializer)); + .variable(ty, None, StorageClass::Private, Some(initializer)); // In all likelihood, this zombie message will get overwritten in SpirvValue::def_with_span // to the use site of this constant. However, if this constant happens to never get used, we diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs index eab5588a96..281dab8e76 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs @@ -1,7 +1,6 @@ use super::CodegenCx; use crate::abi::ConvSpirvType; use crate::spirv_type::SpirvType; -use rspirv::spirv::StorageClass; use rspirv::spirv::Word; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeMethods, LayoutTypeMethods}; @@ -181,25 +180,14 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { } } fn type_ptr_to(&self, ty: Self::Type) -> Self::Type { - SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(DUMMY_SP, self) + SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self) } fn type_ptr_to_ext(&self, ty: Self::Type, _address_space: AddressSpace) -> Self::Type { - SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(DUMMY_SP, self) + SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self) } fn element_type(&self, ty: Self::Type) -> Self::Type { match self.lookup_type(ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, SpirvType::Vector { element, .. } => element, spirv_type => self.tcx.sess.fatal(&format!( "element_type called on invalid type: {:?}", diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 5c14deaca9..c42cd807f4 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -15,8 +15,8 @@ mod zombies; use crate::decorations::{CustomDecoration, UnrollLoopsDecoration}; use rspirv::binary::Consumer; -use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader}; -use rspirv::spirv::{Op, Word}; +use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand}; +use rspirv::spirv::{Op, StorageClass, Word}; use rustc_errors::ErrorReported; use rustc_session::Session; use std::collections::HashMap; @@ -107,6 +107,17 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result, opts: &Options) -> Result { - let result = cx.emit_global().type_pointer(None, storage_class, pointee); + Self::Pointer { pointee } => { + // NOTE(eddyb) we emit `StorageClass::Generic` here, but later + // the linker will specialize the entire SPIR-V module to use + // storage classes inferred from `OpVariable`s. + let result = cx + .emit_global() + .type_pointer(None, StorageClass::Generic, pointee); // no pointers to functions if let Self::Function { .. } = cx.lookup_type(pointee) { cx.zombie_even_in_user_code( @@ -249,13 +250,13 @@ impl SpirvType { return cached; } let result = match self { - Self::Pointer { - storage_class, - pointee, - } => { - let result = cx - .emit_global() - .type_pointer(Some(id), storage_class, pointee); + Self::Pointer { pointee } => { + // NOTE(eddyb) we emit `StorageClass::Generic` here, but later + // the linker will specialize the entire SPIR-V module to use + // storage classes inferred from `OpVariable`s. + let result = + cx.emit_global() + .type_pointer(Some(id), StorageClass::Generic, pointee); // no pointers to functions if let Self::Function { .. } = cx.lookup_type(pointee) { cx.zombie_even_in_user_code( @@ -440,13 +441,9 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { .field("id", &self.id) .field("element", &self.cx.debug_type(element)) .finish(), - SpirvType::Pointer { - storage_class, - pointee, - } => f + SpirvType::Pointer { pointee } => f .debug_struct("Pointer") .field("id", &self.id) - .field("storage_class", &storage_class) .field("pointee", &self.cx.debug_type(pointee)) .finish(), SpirvType::Function { @@ -599,11 +596,8 @@ impl SpirvTypePrinter<'_, '_> { ty(self.cx, stack, f, element)?; f.write_str("]") } - SpirvType::Pointer { - storage_class, - pointee, - } => { - write!(f, "*{{{:?}}} ", storage_class)?; + SpirvType::Pointer { pointee } => { + f.write_str("*")?; ty(self.cx, stack, f, pointee) } SpirvType::Function { diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index 144111c31c..092cd42d6d 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -148,9 +148,12 @@ fn asm_op_decorate() { "%image_2d = OpTypeImage %float Dim2D 0 0 0 1 Unknown", "%sampled_image_2d = OpTypeSampledImage %image_2d", "%image_array = OpTypeRuntimeArray %sampled_image_2d", - "%ptr_image_array = OpTypePointer UniformConstant %image_array", + // NOTE(eddyb) `Generic` is used here because it's the placeholder + // for storage class inference - both of the two `OpTypePointer` + // types below should end up inferring to `UniformConstant`. + "%ptr_image_array = OpTypePointer Generic %image_array", "%image_2d_var = OpVariable %ptr_image_array UniformConstant", - "%ptr_sampled_image_2d = OpTypePointer UniformConstant %sampled_image_2d", + "%ptr_sampled_image_2d = OpTypePointer Generic %sampled_image_2d", "", // ^^ type preamble "%offset = OpLoad _ {0}", "%24 = OpAccessChain %ptr_sampled_image_2d %image_2d_var %offset", @@ -516,9 +519,12 @@ fn complex_image_sample_inst() { "%image_2d = OpTypeImage %float Dim2D 0 0 0 1 Unknown", "%sampled_image_2d = OpTypeSampledImage %image_2d", "%image_array = OpTypeRuntimeArray %sampled_image_2d", - "%ptr_image_array = OpTypePointer UniformConstant %image_array", + // NOTE(eddyb) `Generic` is used here because it's the placeholder + // for storage class inference - both of the two `OpTypePointer` + // types below should end up inferring to `UniformConstant`. + "%ptr_image_array = OpTypePointer Generic %image_array", "%image_2d_var = OpVariable %ptr_image_array UniformConstant", - "%ptr_sampled_image_2d = OpTypePointer UniformConstant %sampled_image_2d", + "%ptr_sampled_image_2d = OpTypePointer Generic %sampled_image_2d", "", // ^^ type preamble "%offset = OpLoad _ {1}", "%24 = OpAccessChain %ptr_sampled_image_2d %image_2d_var %offset",