Infer storage classes using the specializer, replacing special pointer types.

This commit is contained in:
Eduard-Mihai Burtescu 2021-02-10 16:34:57 +02:00 committed by Eduard-Mihai Burtescu
parent cb0bd4b04a
commit 67746012f5
14 changed files with 208 additions and 241 deletions

View File

@ -27,19 +27,12 @@ use std::fmt;
/// tracking.
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
map: RefCell<HashMap<(PointeeTy<'tcx>, StorageClass), PointeeDefState>>,
map: RefCell<HashMap<PointeeTy<'tcx>, PointeeDefState>>,
}
impl<'tcx> RecursivePointeeCache<'tcx> {
fn begin(
&self,
cx: &CodegenCx<'tcx>,
span: Span,
pointee: PointeeTy<'tcx>,
storage_class: StorageClass,
) -> Option<Word> {
// Warning: storage_class must match the one called with end()
match self.map.borrow_mut().entry((pointee, storage_class)) {
fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
match self.map.borrow_mut().entry(pointee) {
// State: This is the first time we've seen this type. Record that we're beginning to translate this type,
// and start doing the translation.
Entry::Vacant(entry) => {
@ -52,7 +45,11 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
// emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm)
PointeeDefState::Defining => {
let new_id = cx.emit_global().id();
cx.emit_global().type_forward_pointer(new_id, storage_class);
// NOTE(eddyb) we emit `StorageClass::Generic` here, but later
// the linker will specialize the entire SPIR-V module to use
// storage classes inferred from `OpVariable`s.
cx.emit_global()
.type_forward_pointer(new_id, StorageClass::Generic);
entry.insert(PointeeDefState::DefiningWithForward(new_id));
if !cx.builder.has_capability(Capability::Addresses)
&& !cx
@ -81,11 +78,9 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
cx: &CodegenCx<'tcx>,
span: Span,
pointee: PointeeTy<'tcx>,
storage_class: StorageClass,
pointee_spv: Word,
) -> Word {
// Warning: storage_class must match the one called with begin()
match self.map.borrow_mut().entry((pointee, storage_class)) {
match self.map.borrow_mut().entry(pointee) {
// We should have hit begin() on this type already, which always inserts an entry.
Entry::Vacant(_) => bug!("RecursivePointeeCache::end should always have entry"),
Entry::Occupied(mut entry) => match *entry.get() {
@ -93,7 +88,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
// OpTypeForwardPointer has been emitted. This is the most common case.
PointeeDefState::Defining => {
let id = SpirvType::Pointer {
storage_class,
pointee: pointee_spv,
}
.def(span, cx);
@ -105,7 +99,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
PointeeDefState::DefiningWithForward(id) => {
entry.insert(PointeeDefState::Defined(id));
SpirvType::Pointer {
storage_class,
pointee: pointee_spv,
}
.def_with_id(cx, span, id)
@ -261,11 +254,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Cast(cast_target) => cast_target.spirv_type(span, cx),
PassMode::Indirect { .. } => {
let pointee = self.ret.layout.spirv_type(span, cx);
let pointer = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee,
}
.def(span, cx);
let pointer = SpirvType::Pointer { pointee }.def(span, cx);
// Important: the return pointer comes *first*, not last.
argument_types.push(pointer);
SpirvType::Void.def(span, cx)
@ -304,11 +293,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
extra_attrs: None, ..
} => {
let pointee = arg.layout.spirv_type(span, cx);
SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee,
}
.def(span, cx)
SpirvType::Pointer { pointee }.def(span, cx)
}
};
argument_types.push(arg_type);
@ -465,26 +450,20 @@ fn trans_scalar<'tcx>(
Primitive::F32 => SpirvType::Float(32).def(span, cx),
Primitive::F64 => SpirvType::Float(64).def(span, cx),
Primitive::Pointer => {
let (storage_class, pointee_ty) = dig_scalar_pointee(cx, ty, index);
// Default to function storage class.
let storage_class = storage_class.unwrap_or(StorageClass::Function);
let pointee_ty = dig_scalar_pointee(cx, ty, index);
// Pointers can be recursive. So, record what we're currently translating, and if we're already translating
// the same type, emit an OpTypeForwardPointer and use that ID.
if let Some(predefined_result) =
cx.type_cache
if let Some(predefined_result) = cx
.type_cache
.recursive_pointee_cache
.begin(cx, span, pointee_ty, storage_class)
.begin(cx, span, pointee_ty)
{
predefined_result
} else {
let pointee = pointee_ty.spirv_type(span, cx);
cx.type_cache.recursive_pointee_cache.end(
cx,
span,
pointee_ty,
storage_class,
pointee,
)
cx.type_cache
.recursive_pointee_cache
.end(cx, span, pointee_ty, pointee)
}
}
}
@ -504,12 +483,12 @@ fn dig_scalar_pointee<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
index: Option<usize>,
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
) -> PointeeTy<'tcx> {
match *ty.ty.kind() {
TyKind::Ref(_, elem_ty, _) | TyKind::RawPtr(TypeAndMut { ty: elem_ty, .. }) => {
let elem = cx.layout_of(elem_ty);
match index {
None => (None, PointeeTy::Ty(elem)),
None => PointeeTy::Ty(elem),
Some(index) => {
if elem.is_unsized() {
dig_scalar_pointee(cx, ty.field(cx, index), None)
@ -518,12 +497,12 @@ fn dig_scalar_pointee<'tcx>(
// of ScalarPair could be deduced, but it's actually e.g. a sized pointer followed by some other
// completely unrelated type, not a wide pointer. So, translate this as a single scalar, one
// component of that ScalarPair.
(None, PointeeTy::Ty(elem))
PointeeTy::Ty(elem)
}
}
}
}
TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)),
TyKind::FnPtr(sig) if index.is_none() => PointeeTy::Fn(sig),
TyKind::Adt(def, _) if def.is_box() => {
let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty()));
dig_scalar_pointee(cx, ptr_ty, index)
@ -542,11 +521,8 @@ fn dig_scalar_pointee_adt<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
index: Option<usize>,
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
// Storage classes can only be applied on structs containing a single pointer field (because we said so), so we only
// need to handle the attribute here.
let storage_class = get_storage_class(cx, ty);
let result = match &ty.variants {
) -> PointeeTy<'tcx> {
match &ty.variants {
// If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the
// discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying
// type, only available in the dataful variant.
@ -603,21 +579,16 @@ fn dig_scalar_pointee_adt<'tcx>(
},
}
}
};
match (storage_class, result) {
(storage_class, (None, result)) => (storage_class, result),
(None, (storage_class, result)) => (storage_class, result),
(Some(one), (Some(two), _)) => cx.tcx.sess.fatal(&format!(
"Double-applied storage class ({:?} and {:?}) on type {}",
one, two, ty.ty
)),
}
}
/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the scalar translation code, because this is only
/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the entry interface variables code, because this is only
/// used for spooky builtin stuff, and we pinky promise to never have more than one pointer field in one of these.
// TODO: Enforce this is only used in spirv-std.
fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> {
pub(crate) fn get_storage_class<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
) -> Option<StorageClass> {
if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
if let SpirvAttribute::StorageClass(storage_class) = attr {

View File

@ -727,11 +727,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
fn alloca(&mut self, ty: Self::Type, _align: Align) -> Self::Value {
let ptr_ty = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: ty,
}
.def(self.span(), self);
let ptr_ty = SpirvType::Pointer { pointee: ty }.def(self.span(), self);
// "All OpVariable instructions in a function must be the first instructions in the first block."
let mut builder = self.emit();
builder.select_block(Some(0)).unwrap();
@ -779,10 +775,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return value;
}
let ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"load called on variable that wasn't a pointer: {:?}",
ty
@ -803,10 +796,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn atomic_load(&mut self, ptr: Self::Value, order: AtomicOrdering, _size: Size) -> Self::Value {
let ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_load called on variable that wasn't a pointer: {:?}",
ty
@ -906,10 +896,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value {
let ptr_elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"store called on variable that wasn't a pointer: {:?}",
ty
@ -946,10 +933,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
_size: Size,
) {
let ptr_elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_store called on variable that wasn't a pointer: {:?}",
ty
@ -979,15 +963,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
fn struct_gep(&mut self, ptr: Self::Value, idx: u64) -> Self::Value {
let (storage_class, result_pointee_type) = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class,
pointee,
} => match self.lookup_type(pointee) {
SpirvType::Adt { field_types, .. } => (storage_class, field_types[idx as usize]),
let result_pointee_type = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Vector { element, .. } => (storage_class, element),
| SpirvType::Vector { element, .. } => element,
other => self.fatal(&format!(
"struct_gep not on struct, array, or vector type: {:?}, index {}",
other, idx
@ -999,7 +980,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)),
};
let result_type = SpirvType::Pointer {
storage_class,
pointee: result_pointee_type,
}
.def(self.span(), self);
@ -1225,14 +1205,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
let val_pointee = match self.lookup_type(val.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(&format!(
"pointercast called on non-pointer source type: {:?}",
other
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(&format!(
"pointercast called on non-pointer dest type: {:?}",
other
@ -1531,7 +1511,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return;
}
let src_pointee = match self.lookup_type(src.ty) {
SpirvType::Pointer { pointee, .. } => Some(pointee),
SpirvType::Pointer { pointee } => Some(pointee),
_ => None,
};
let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self));
@ -1592,7 +1572,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
));
}
let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
_ => self.fatal(&format!(
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
@ -1780,10 +1760,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
_weak: bool,
) -> Self::Value {
let dst_pointee_ty = match self.lookup_type(dst.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_cmpxchg called on variable that wasn't a pointer: {:?}",
ty
@ -1820,10 +1797,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
order: AtomicOrdering,
) -> Self::Value {
let dst_pointee_ty = match self.lookup_type(dst.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_rmw called on variable that wasn't a pointer: {:?}",
ty

View File

@ -3,7 +3,7 @@ use crate::abi::ConvSpirvType;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::{CLOp, GLOp, StorageClass};
use rspirv::spirv::{CLOp, GLOp};
use rustc_codegen_ssa::mir::operand::OperandRef;
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{BuilderMethods, IntrinsicCallMethods};
@ -103,11 +103,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
let mut ptr = args[0].immediate();
if let PassMode::Cast(ty) = fn_abi.ret.mode {
let pointee = ty.spirv_type(self.span(), self);
let pointer = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee,
}
.def(self.span(), self);
let pointer = SpirvType::Pointer { pointee }.def(self.span(), self);
ptr = self.pointercast(ptr, pointer);
}
let load = self.volatile_load(ptr);

View File

@ -12,7 +12,7 @@ use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::{StorageClass, Word};
use rspirv::spirv::Word;
use rustc_codegen_ssa::mir::operand::OperandValue;
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
@ -104,11 +104,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// "An OpAccessChain instruction is the equivalent of an LLVM getelementptr instruction where the first index element is zero."
// https://github.com/gpuweb/gpuweb/issues/33
let mut result_indices = Vec::with_capacity(indices.len() - 1);
let (storage_class, mut result_pointee_type) = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class,
pointee,
} => (storage_class, pointee),
let mut result_pointee_type = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
other_type => self.fatal(&format!(
"GEP first deref not implemented for type {:?}",
other_type
@ -125,7 +122,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
}
let result_type = SpirvType::Pointer {
storage_class,
pointee: result_pointee_type,
}
.def(self.span(), self);
@ -330,11 +326,7 @@ impl<'a, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'tcx> {
self.fatal("unsized `ArgAbi` must be handled through `store_fn_arg`");
} else if let PassMode::Cast(cast) = arg_abi.mode {
let cast_ty = cast.spirv_type(self.span(), self);
let cast_ptr_ty = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: cast_ty,
}
.def(self.span(), self);
let cast_ptr_ty = SpirvType::Pointer { pointee: cast_ty }.def(self.span(), self);
let cast_dst = self.pointercast(dst.llval, cast_ptr_ty);
self.store(val, cast_dst, arg_abi.layout.align.abi);
} else {

View File

@ -263,11 +263,25 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
element: inst.operands[0].unwrap_id_ref(),
}
.def(self.span(), self),
Op::TypePointer => SpirvType::Pointer {
storage_class: inst.operands[0].unwrap_storage_class(),
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
if storage_class != StorageClass::Generic {
self.struct_err("TypePointer in asm! requires `Generic` storage class")
.note(&format!(
"`{:?}` storage class was specified",
storage_class
))
.help(&format!(
"the storage class will be inferred automatically (e.g. to `{:?}`)",
storage_class
))
.emit();
}
SpirvType::Pointer {
pointee: inst.operands[1].unwrap_id_ref(),
}
.def(self.span(), self),
.def(self.span(), self)
}
Op::TypeImage => SpirvType::Image {
sampled_type: inst.operands[0].unwrap_id_ref(),
dim: inst.operands[1].unwrap_dim(),
@ -511,24 +525,26 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
use crate::spirv_type_constraints::{instruction_signatures, InstSig, TyListPat, TyPat};
#[derive(Debug)]
struct Mismatch;
struct Unapplicable;
/// Recursively match `ty` against `pat`, returning one of:
/// * `Ok(None)`: `pat` matched but contained no type variables
/// * `Ok(Some(var))`: `pat` matched and `var` is the type variable
/// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now
fn apply_ty_pat(
fn match_ty_pat(
cx: &CodegenCx<'_>,
pat: &TyPat<'_>,
ty: Word,
) -> Result<Option<Word>, Mismatch> {
) -> Result<Option<Word>, Unapplicable> {
match pat {
TyPat::Any => Ok(None),
&TyPat::T => Ok(Some(ty)),
TyPat::Either(a, b) => {
apply_ty_pat(cx, a, ty).or_else(|Mismatch| apply_ty_pat(cx, b, ty))
match_ty_pat(cx, a, ty).or_else(|Unapplicable| match_ty_pat(cx, b, ty))
}
_ => match (pat, cx.lookup_type(ty)) {
(TyPat::Any, _) | (&TyPat::T, _) | (TyPat::Either(..), _) => unreachable!(),
(TyPat::Void, SpirvType::Void) => Ok(None),
(TyPat::Pointer(_, pat), SpirvType::Pointer { pointee: ty, .. })
| (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. })
@ -546,9 +562,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
},
)
| (TyPat::SampledImage(pat), SpirvType::SampledImage { image_type: ty }) => {
apply_ty_pat(cx, pat, ty)
match_ty_pat(cx, pat, ty)
}
_ => Err(Mismatch),
_ => Err(Unapplicable),
},
}
}
@ -567,17 +583,19 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
while let TyListPat::Cons { first: pat, suffix } = *sig.input_types {
let apply_result = match id_to_type_map.get(&ids.next()?) {
Some(&ty) => apply_ty_pat(self, pat, ty),
sig.input_types = suffix;
let match_result = match id_to_type_map.get(&ids.next()?) {
Some(&ty) => match_ty_pat(self, pat, ty),
// Non-value ID operand (or value operand of unknown type),
// only `TyPat::Any` is valid.
None => match pat {
TyPat::Any => Ok(None),
_ => Err(Mismatch),
_ => Err(Unapplicable),
},
};
match apply_result {
match match_result {
Ok(Some(var)) => match combined_var {
Some(combined_var) => {
// FIXME(eddyb) this could use some error reporting
@ -591,11 +609,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
None => combined_var = Some(var),
},
Ok(None) => {}
Err(Mismatch) => return None,
Err(Unapplicable) => return None,
}
sig.input_types = suffix;
}
match sig.input_types {
TyListPat::Cons { .. } => unreachable!(),
TyListPat::Any => {}
TyListPat::Nil => {
if ids.next().is_some() {
@ -742,7 +761,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Some(match kind {
TypeofKind::Plain => ty,
TypeofKind::Dereference => match self.lookup_type(ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => {
self.tcx.sess.span_err(
span,
@ -764,7 +783,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.check_reg(span, reg);
match place {
Some(place) => match self.lookup_type(place.llval.ty) {
SpirvType::Pointer { pointee, .. } => Some(pointee),
SpirvType::Pointer { pointee } => Some(pointee),
other => {
self.tcx.sess.span_err(
span,

View File

@ -44,10 +44,7 @@ impl SpirvValue {
global_var: _,
} => {
let ty = match cx.lookup_type(self.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => bug!("load called on variable that wasn't a pointer: {:?}", ty),
};
Some(initializer.with_type(ty))

View File

@ -271,7 +271,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
let (base_addr, _base_addr_space) = match self.tcx.global_alloc(ptr.alloc_id) {
GlobalAlloc::Memory(alloc) => {
let pointee = match self.lookup_type(ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => self.tcx.sess.fatal(&format!(
"GlobalAlloc::Memory type not implemented: {}",
other.debug(ty, self)
@ -306,28 +306,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
.fatal("Non-pointer-typed scalar_to_backend Scalar::Ptr not supported");
// unsafe { llvm::LLVMConstPtrToInt(llval, llty) }
} else {
match (self.lookup_type(value.ty), self.lookup_type(ty)) {
(
SpirvType::Pointer {
storage_class: a_space,
pointee: a,
},
SpirvType::Pointer {
storage_class: b_space,
pointee: b,
},
) => {
if a_space != b_space {
// TODO: Emit the correct type that is passed into this function.
self.zombie_no_span(
value.def_cx(self),
"invalid pointer space in constant",
);
}
assert_ty_eq!(self, a, b);
}
_ => assert_ty_eq!(self, value.ty, ty),
}
assert_ty_eq!(self, value.ty, ty);
value
}
}

View File

@ -184,14 +184,11 @@ impl<'tcx> CodegenCx<'tcx> {
}
fn declare_global(&self, span: Span, ty: Word) -> SpirvValue {
let ptr_ty = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: ty,
}
.def(span, self);
let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self);
// FIXME(eddyb) figure out what the correct storage class is.
let result = self
.emit_global()
.variable(ptr_ty, None, StorageClass::Function, None)
.variable(ptr_ty, None, StorageClass::Private, None)
.with_type(ptr_ty);
// TODO: These should be StorageClass::Private, so just zombie for now.
self.zombie_with_span(result.def_cx(self), span, "Globals are not supported yet");
@ -264,7 +261,7 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> {
Err(_) => return,
};
let value_ty = match self.lookup_type(g.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => self.tcx.sess.fatal(&format!(
"global had non-pointer type {}",
other.debug(g.ty, self)

View File

@ -1,10 +1,12 @@
use super::CodegenCx;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::SpirvValue;
use crate::spirv_type::SpirvType;
use crate::symbols::{parse_attrs, Entry, SpirvAttribute};
use rspirv::dr::Operand;
use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word};
use rustc_hir::{Param, PatKind};
use rustc_hir as hir;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{Instance, Ty};
use rustc_span::Span;
use rustc_target::abi::call::{FnAbi, PassMode};
@ -18,7 +20,7 @@ impl<'tcx> CodegenCx<'tcx> {
pub fn entry_stub(
&self,
instance: &Instance<'_>,
fn_abi: &FnAbi<'_, Ty<'_>>,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
entry_func: SpirvValue,
name: String,
entry: Entry,
@ -60,6 +62,7 @@ impl<'tcx> CodegenCx<'tcx> {
self.shader_entry_stub(
self.tcx.def_span(instance.def_id()),
entry_func,
fn_abi,
body.params,
name,
execution_model,
@ -78,7 +81,8 @@ impl<'tcx> CodegenCx<'tcx> {
&self,
span: Span,
entry_func: SpirvValue,
hir_params: &[Param<'tcx>],
entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
hir_params: &[hir::Param<'tcx>],
name: String,
execution_model: ExecutionModel,
) -> Word {
@ -88,11 +92,11 @@ impl<'tcx> CodegenCx<'tcx> {
arguments: vec![],
}
.def(span, self);
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
let entry_func_return_type = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments,
} => (return_type, arguments),
arguments: _,
} => return_type,
other => self.tcx.sess.fatal(&format!(
"Invalid entry_stub type: {}",
other.debug(entry_func.ty, self)
@ -100,11 +104,12 @@ impl<'tcx> CodegenCx<'tcx> {
};
let mut decoration_locations = HashMap::new();
// Create OpVariables before OpFunction so they're global instead of local vars.
let arguments = entry_func_args
let arguments = entry_fn_abi
.args
.iter()
.zip(hir_params)
.map(|(&arg, hir_param)| {
self.declare_parameter(arg, hir_param, &mut decoration_locations)
.map(|(entry_fn_arg, hir_param)| {
self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations)
})
.collect::<Vec<_>>();
let mut emit = self.emit_global();
@ -113,7 +118,7 @@ impl<'tcx> CodegenCx<'tcx> {
.unwrap();
emit.begin_block(None).unwrap();
emit.function_call(
entry_func_return,
entry_func_return_type,
None,
entry_func.def_cx(self),
arguments.iter().map(|&(a, _)| a),
@ -139,24 +144,26 @@ impl<'tcx> CodegenCx<'tcx> {
fn declare_parameter(
&self,
arg: Word,
hir_param: &Param<'tcx>,
layout: TyAndLayout<'tcx>,
hir_param: &hir::Param<'tcx>,
decoration_locations: &mut HashMap<StorageClass, u32>,
) -> (Word, StorageClass) {
let storage_class = match self.lookup_type(arg) {
SpirvType::Pointer { storage_class, .. } => storage_class,
other => self.tcx.sess.fatal(&format!(
"Invalid entry arg type {}",
other.debug(arg, self)
)),
};
let storage_class = crate::abi::get_storage_class(self, layout).unwrap_or_else(|| {
self.tcx.sess.span_fatal(
hir_param.span,
&format!("invalid entry param type `{}`", layout.ty),
);
});
let mut has_location = matches!(
storage_class,
StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant
);
// Note: this *declares* the variable too.
let variable = self.emit_global().variable(arg, None, storage_class, None);
if let PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
let spirv_type = layout.spirv_type(hir_param.span, self);
let variable = self
.emit_global()
.variable(spirv_type, None, storage_class, None);
if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
self.emit_global().name(variable, ident.to_string());
}
for attr in parse_attrs(self, hir_param.attrs) {

View File

@ -217,18 +217,15 @@ impl<'tcx> CodegenCx<'tcx> {
/// See note on `SpirvValueKind::ConstantPointer`
pub fn make_constant_pointer(&self, span: Span, value: SpirvValue) -> SpirvValue {
let ty = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: value.ty,
}
.def(span, self);
let ty = SpirvType::Pointer { pointee: value.ty }.def(span, self);
let initializer = value.def_cx(self);
// Create these up front instead of on demand in SpirvValue::def because
// SpirvValue::def can't use cx.emit()
// FIXME(eddyb) figure out what the correct storage class is.
let global_var =
self.emit_global()
.variable(ty, None, StorageClass::Function, Some(initializer));
.variable(ty, None, StorageClass::Private, Some(initializer));
// In all likelihood, this zombie message will get overwritten in SpirvValue::def_with_span
// to the use site of this constant. However, if this constant happens to never get used, we

View File

@ -1,7 +1,6 @@
use super::CodegenCx;
use crate::abi::ConvSpirvType;
use crate::spirv_type::SpirvType;
use rspirv::spirv::StorageClass;
use rspirv::spirv::Word;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeMethods, LayoutTypeMethods};
@ -181,25 +180,14 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
}
}
fn type_ptr_to(&self, ty: Self::Type) -> Self::Type {
SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: ty,
}
.def(DUMMY_SP, self)
SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self)
}
fn type_ptr_to_ext(&self, ty: Self::Type, _address_space: AddressSpace) -> Self::Type {
SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: ty,
}
.def(DUMMY_SP, self)
SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self)
}
fn element_type(&self, ty: Self::Type) -> Self::Type {
match self.lookup_type(ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
SpirvType::Vector { element, .. } => element,
spirv_type => self.tcx.sess.fatal(&format!(
"element_type called on invalid type: {:?}",

View File

@ -15,8 +15,8 @@ mod zombies;
use crate::decorations::{CustomDecoration, UnrollLoopsDecoration};
use rspirv::binary::Consumer;
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader};
use rspirv::spirv::{Op, Word};
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use rustc_errors::ErrorReported;
use rustc_session::Session;
use std::collections::HashMap;
@ -107,6 +107,17 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<M
output
};
if let Ok(ref path) = std::env::var("DUMP_POST_MERGE") {
use rspirv::binary::Assemble;
use std::fs::File;
use std::io::Write;
File::create(path)
.unwrap()
.write_all(spirv_tools::binary::from_binary(&output.assemble()))
.unwrap();
}
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
{
let _timer = sess.timer("link_remove_duplicates");
@ -128,6 +139,45 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<M
zombies::remove_zombies(sess, &mut output);
}
// HACK(eddyb) run DCE before specialization, not just after inlining,
// to remove needed work and chance of conflicts (in dead code).
// We can't do specialization before inlining because inlining assumes
// it can rely on storage classes being the correct final ones.
if opts.dce {
let _timer = sess.timer("link_dce");
dce::dce(&mut output);
}
{
let _timer = sess.timer("specialize_generic_storage_class");
// HACK(eddyb) `specializer` requires functions' blocks to be in RPO order
// (i.e. `block_ordering_pass`) - this could be relaxed by using RPO visit
// inside `specializer`, but this is easier.
for func in &mut output.functions {
simple_passes::block_ordering_pass(func);
}
output = specializer::specialize(
output,
specializer::SimpleSpecialization {
specialize_operand: |operand| {
matches!(operand, Operand::StorageClass(StorageClass::Generic))
},
// NOTE(eddyb) this can be anything that is guaranteed to pass
// validation - there are no constraints so this is either some
// unused pointer, or perhaps one created using `OpConstantNull`
// and simply never mixed with pointers that have a storage class.
// It would be nice to use `Generic` itself here so that we leave
// some kind of indication of it being unconstrained, but `Generic`
// requires additional capabilities, so we use `Function` instead.
// TODO(eddyb) investigate whether this can end up in a pointer
// type that's the value of a module-scoped variable, and whether
// `Function` is actually invalid! (may need `Private`)
concrete_fallback: Operand::StorageClass(StorageClass::Function),
},
);
}
if opts.inline {
let _timer = sess.timer("link_inline");
inline::inline(&mut output);

View File

@ -60,7 +60,6 @@ pub enum SpirvType {
element: Word,
},
Pointer {
storage_class: StorageClass,
pointee: Word,
},
Function {
@ -194,11 +193,13 @@ impl SpirvType {
}
result
}
Self::Pointer {
storage_class,
pointee,
} => {
let result = cx.emit_global().type_pointer(None, storage_class, pointee);
Self::Pointer { pointee } => {
// NOTE(eddyb) we emit `StorageClass::Generic` here, but later
// the linker will specialize the entire SPIR-V module to use
// storage classes inferred from `OpVariable`s.
let result = cx
.emit_global()
.type_pointer(None, StorageClass::Generic, pointee);
// no pointers to functions
if let Self::Function { .. } = cx.lookup_type(pointee) {
cx.zombie_even_in_user_code(
@ -249,13 +250,13 @@ impl SpirvType {
return cached;
}
let result = match self {
Self::Pointer {
storage_class,
pointee,
} => {
let result = cx
.emit_global()
.type_pointer(Some(id), storage_class, pointee);
Self::Pointer { pointee } => {
// NOTE(eddyb) we emit `StorageClass::Generic` here, but later
// the linker will specialize the entire SPIR-V module to use
// storage classes inferred from `OpVariable`s.
let result =
cx.emit_global()
.type_pointer(Some(id), StorageClass::Generic, pointee);
// no pointers to functions
if let Self::Function { .. } = cx.lookup_type(pointee) {
cx.zombie_even_in_user_code(
@ -440,13 +441,9 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("id", &self.id)
.field("element", &self.cx.debug_type(element))
.finish(),
SpirvType::Pointer {
storage_class,
pointee,
} => f
SpirvType::Pointer { pointee } => f
.debug_struct("Pointer")
.field("id", &self.id)
.field("storage_class", &storage_class)
.field("pointee", &self.cx.debug_type(pointee))
.finish(),
SpirvType::Function {
@ -599,11 +596,8 @@ impl SpirvTypePrinter<'_, '_> {
ty(self.cx, stack, f, element)?;
f.write_str("]")
}
SpirvType::Pointer {
storage_class,
pointee,
} => {
write!(f, "*{{{:?}}} ", storage_class)?;
SpirvType::Pointer { pointee } => {
f.write_str("*")?;
ty(self.cx, stack, f, pointee)
}
SpirvType::Function {

View File

@ -148,9 +148,12 @@ fn asm_op_decorate() {
"%image_2d = OpTypeImage %float Dim2D 0 0 0 1 Unknown",
"%sampled_image_2d = OpTypeSampledImage %image_2d",
"%image_array = OpTypeRuntimeArray %sampled_image_2d",
"%ptr_image_array = OpTypePointer UniformConstant %image_array",
// NOTE(eddyb) `Generic` is used here because it's the placeholder
// for storage class inference - both of the two `OpTypePointer`
// types below should end up inferring to `UniformConstant`.
"%ptr_image_array = OpTypePointer Generic %image_array",
"%image_2d_var = OpVariable %ptr_image_array UniformConstant",
"%ptr_sampled_image_2d = OpTypePointer UniformConstant %sampled_image_2d",
"%ptr_sampled_image_2d = OpTypePointer Generic %sampled_image_2d",
"", // ^^ type preamble
"%offset = OpLoad _ {0}",
"%24 = OpAccessChain %ptr_sampled_image_2d %image_2d_var %offset",
@ -516,9 +519,12 @@ fn complex_image_sample_inst() {
"%image_2d = OpTypeImage %float Dim2D 0 0 0 1 Unknown",
"%sampled_image_2d = OpTypeSampledImage %image_2d",
"%image_array = OpTypeRuntimeArray %sampled_image_2d",
"%ptr_image_array = OpTypePointer UniformConstant %image_array",
// NOTE(eddyb) `Generic` is used here because it's the placeholder
// for storage class inference - both of the two `OpTypePointer`
// types below should end up inferring to `UniformConstant`.
"%ptr_image_array = OpTypePointer Generic %image_array",
"%image_2d_var = OpVariable %ptr_image_array UniformConstant",
"%ptr_sampled_image_2d = OpTypePointer UniformConstant %sampled_image_2d",
"%ptr_sampled_image_2d = OpTypePointer Generic %sampled_image_2d",
"", // ^^ type preamble
"%offset = OpLoad _ {1}",
"%24 = OpAccessChain %ptr_sampled_image_2d %image_2d_var %offset",