Arena-allocate slices to replace Vecs in SpirvType and SpirvConst.

This commit is contained in:
Eduard-Mihai Burtescu 2022-11-30 06:15:42 +02:00 committed by Eduard-Mihai Burtescu
parent 1000dece4a
commit acb05d3799
11 changed files with 260 additions and 162 deletions

View File

@ -17,8 +17,8 @@ use rustc_middle::ty::{
};
use rustc_middle::{bug, span_bug};
use rustc_span::def_id::DefId;
use rustc_span::Span;
use rustc_span::DUMMY_SP;
use rustc_span::{Span, Symbol};
use rustc_target::abi::call::{ArgAbi, ArgAttributes, FnAbi, PassMode};
use rustc_target::abi::{
Abi, Align, FieldsShape, LayoutS, Primitive, Scalar, Size, TagEncoding, VariantIdx, Variants,
@ -300,6 +300,7 @@ impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
let mut argument_types = Vec::new();
let return_type = match self.ret.mode {
@ -332,7 +333,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
SpirvType::Function {
return_type,
arguments: argument_types,
arguments: &argument_types,
}
.def(span, cx)
}
@ -364,8 +365,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
def_id: def_id_for_spirv_type_adt(*self),
size: Some(Size::ZERO),
align: Align::from_bytes(0).unwrap(),
field_types: Vec::new(),
field_offsets: Vec::new(),
field_types: &[],
field_offsets: &[],
field_names: None,
}
.def_with_name(cx, span, TyLayoutNameKey::from(*self)),
@ -416,12 +417,13 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
} else {
Some(self.size)
};
// FIXME(eddyb) use `ArrayVec` here.
let mut field_names = Vec::new();
if let TyKind::Adt(adt, _) = self.ty.kind() {
if let Variants::Single { index } = self.variants {
for i in self.fields.index_by_increasing_offset() {
let field = &adt.variants()[index].fields[i];
field_names.push(field.name.to_ident_string());
field_names.push(field.name);
}
}
}
@ -429,10 +431,10 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
def_id: def_id_for_spirv_type_adt(*self),
size,
align: self.align.abi,
field_types: vec![a, b],
field_offsets: vec![a_offset, b_offset],
field_types: &[a, b],
field_offsets: &[a_offset, b_offset],
field_names: if field_names.len() == 2 {
Some(field_names)
Some(&field_names)
} else {
None
},
@ -598,8 +600,8 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
def_id: def_id_for_spirv_type_adt(ty),
size: Some(Size::ZERO),
align: Align::from_bytes(0).unwrap(),
field_types: Vec::new(),
field_offsets: Vec::new(),
field_types: &[],
field_offsets: &[],
field_names: None,
}
.def_with_name(cx, span, TyLayoutNameKey::from(ty))
@ -664,6 +666,7 @@ pub fn auto_struct_layout<'tcx>(
cx: &CodegenCx<'tcx>,
field_types: &[Word],
) -> (Vec<Size>, Option<Size>, Align) {
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
let mut field_offsets = Vec::with_capacity(field_types.len());
let mut offset = Some(Size::ZERO);
let mut max_align = Align::from_bytes(0).unwrap();
@ -688,6 +691,7 @@ pub fn auto_struct_layout<'tcx>(
fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
let size = if ty.is_unsized() { None } else { Some(ty.size) };
let align = ty.align.abi;
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
let mut field_types = Vec::new();
let mut field_offsets = Vec::new();
let mut field_names = Vec::new();
@ -699,9 +703,10 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
if let Variants::Single { index } = ty.variants {
if let TyKind::Adt(adt, _) = ty.ty.kind() {
let field = &adt.variants()[index].fields[i];
field_names.push(field.name.to_ident_string());
field_names.push(field.name);
} else {
field_names.push(format!("{}", i));
// FIXME(eddyb) this looks like something that should exist in rustc.
field_names.push(Symbol::intern(&format!("{i}")));
}
} else {
if let TyKind::Adt(_, _) = ty.ty.kind() {
@ -709,7 +714,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
span_bug!(span, "Variants::Multiple not TyKind::Adt");
}
if i == 0 {
field_names.push("discriminant".to_string());
field_names.push(cx.sym.discriminant);
} else {
cx.tcx.sess.fatal("Variants::Multiple has multiple fields")
}
@ -719,9 +724,9 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
def_id: def_id_for_spirv_type_adt(ty),
size,
align,
field_types,
field_offsets,
field_names: Some(field_names),
field_types: &field_types,
field_offsets: &field_offsets,
field_names: Some(&field_names),
}
.def_with_name(cx, span, TyLayoutNameKey::from(ty))
}

View File

@ -176,7 +176,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
semantics
}
fn memset_const_pattern(&self, ty: &SpirvType, fill_byte: u8) -> Word {
fn memset_const_pattern(&self, ty: &SpirvType<'tcx>, fill_byte: u8) -> Word {
match *ty {
SpirvType::Void => self.fatal("memset invalid on void pattern"),
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
@ -212,7 +212,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
self.constant_composite(
ty.clone().def(self.span(), self),
ty.def(self.span(), self),
iter::repeat(elem_pat).take(count as usize),
)
.def(self)
@ -221,7 +221,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
self.constant_composite(
ty.clone().def(self.span(), self),
ty.def(self.span(), self),
iter::repeat(elem_pat).take(count),
)
.def(self)
@ -242,7 +242,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
}
fn memset_dynamic_pattern(&self, ty: &SpirvType, fill_var: Word) -> Word {
fn memset_dynamic_pattern(&self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word {
match *ty {
SpirvType::Void => self.fatal("memset invalid on void pattern"),
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
@ -270,7 +270,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
self.emit()
.composite_construct(
ty.clone().def(self.span(), self),
ty.def(self.span(), self),
None,
iter::repeat(elem_pat).take(count),
)
@ -280,7 +280,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
self.emit()
.composite_construct(
ty.clone().def(self.span(), self),
ty.def(self.span(), self),
None,
iter::repeat(elem_pat).take(count as usize),
)
@ -1260,9 +1260,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
};
let pointee_kind = self.lookup_type(pointee);
let result_pointee_type = match pointee_kind {
SpirvType::Adt {
ref field_types, ..
} => field_types[idx as usize],
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Vector { element, .. }
@ -2345,7 +2343,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
),
};
for (argument, argument_type) in args.iter().zip(argument_types) {
for (argument, &argument_type) in args.iter().zip(argument_types) {
assert_ty_eq!(self, argument.ty, argument_type);
}
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).copied();

View File

@ -1081,7 +1081,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
(OperandKind::LiteralContextDependentNumber, Some(word)) => {
assert!(matches!(inst.class.opcode, Op::Constant | Op::SpecConstant));
let ty = inst.result_type.unwrap();
fn parse(ty: SpirvType, w: &str) -> Result<dr::Operand, String> {
fn parse(ty: SpirvType<'_>, w: &str) -> Result<dr::Operand, String> {
fn fmt(x: impl ToString) -> String {
x.to_string()
}

View File

@ -13,7 +13,6 @@ use rustc_span::symbol::Symbol;
use rustc_span::{Span, DUMMY_SP};
use std::assert_matches::assert_matches;
use std::cell::{RefCell, RefMut};
use std::rc::Rc;
use std::{fs::File, io::Write, path::Path};
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
@ -68,7 +67,7 @@ impl SpirvValue {
pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option<Self> {
match self.kind {
SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
let entry = cx.builder.id_to_const.borrow().get(&id)?.clone();
let &entry = cx.builder.id_to_const.borrow().get(&id)?;
match entry.val {
SpirvConst::PtrTo { pointee } => {
let ty = match cx.lookup_type(self.ty) {
@ -213,8 +212,8 @@ impl SpirvValueExt for Word {
}
}
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvConst {
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvConst<'tcx> {
U32(u32),
U64(u64),
/// f32 isn't hash, so store bits
@ -232,8 +231,7 @@ pub enum SpirvConst {
// different functions, but of the same type, don't overlap their zombies.
ZombieUndefForFnAddr,
// FIXME(eddyb) use `tcx.arena.dropless` to get `&'tcx [_]`, instead of `Rc`.
Composite(Rc<[Word]>),
Composite(&'tcx [Word]),
/// Pointer to constant data, i.e. `&pointee`, represented as an `OpVariable`
/// in the `Private` storage class, and with `pointee` as its initializer.
@ -242,6 +240,40 @@ pub enum SpirvConst {
},
}
impl SpirvConst<'_> {
/// Replace `&[T]` fields with `&'tcx [T]` ones produced by calling
/// `tcx.arena.dropless.alloc_slice(...)` - this is done late for two reasons:
/// 1. it avoids allocating in the arena when the cache would be hit anyway,
/// which would create "garbage" (as in, unreachable allocations)
/// (ideally these would also be interned, but that's even more refactors)
/// 2. an empty slice is disallowed (as it's usually handled as a special
/// case elsewhere, e.g. `rustc`'s `ty::List` - sadly we can't use that)
fn tcx_arena_alloc_slices<'tcx>(self, cx: &CodegenCx<'tcx>) -> SpirvConst<'tcx> {
fn arena_alloc_slice<'tcx, T: Copy>(cx: &CodegenCx<'tcx>, xs: &[T]) -> &'tcx [T] {
if xs.is_empty() {
&[]
} else {
cx.tcx.arena.dropless.alloc_slice(xs)
}
}
match self {
// FIXME(eddyb) these are all noop cases, could they be automated?
SpirvConst::U32(v) => SpirvConst::U32(v),
SpirvConst::U64(v) => SpirvConst::U64(v),
SpirvConst::F32(v) => SpirvConst::F32(v),
SpirvConst::F64(v) => SpirvConst::F64(v),
SpirvConst::Bool(v) => SpirvConst::Bool(v),
SpirvConst::Null => SpirvConst::Null,
SpirvConst::Undef => SpirvConst::Undef,
SpirvConst::ZombieUndefForFnAddr => SpirvConst::ZombieUndefForFnAddr,
SpirvConst::PtrTo { pointee } => SpirvConst::PtrTo { pointee },
SpirvConst::Composite(fields) => SpirvConst::Composite(arena_alloc_slice(cx, fields)),
}
}
}
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct WithType<V> {
ty: Word,
@ -317,22 +349,22 @@ pub struct BuilderCursor {
pub block: Option<usize>,
}
pub struct BuilderSpirv {
pub struct BuilderSpirv<'tcx> {
builder: RefCell<Builder>,
// Bidirectional maps between `SpirvConst` and the ID of the defined global
// (e.g. `OpConstant...`) instruction.
// NOTE(eddyb) both maps have `WithConstLegality` around their keys, which
// allows getting that legality information without additional lookups.
const_to_id: RefCell<FxHashMap<WithType<SpirvConst>, WithConstLegality<Word>>>,
id_to_const: RefCell<FxHashMap<Word, WithConstLegality<SpirvConst>>>,
const_to_id: RefCell<FxHashMap<WithType<SpirvConst<'tcx>>, WithConstLegality<Word>>>,
id_to_const: RefCell<FxHashMap<Word, WithConstLegality<SpirvConst<'tcx>>>>,
string_cache: RefCell<FxHashMap<String, Word>>,
enabled_capabilities: FxHashSet<Capability>,
enabled_extensions: FxHashSet<Symbol>,
}
impl BuilderSpirv {
impl<'tcx> BuilderSpirv<'tcx> {
pub fn new(sym: &Symbols, target: &SpirvTarget, features: &[TargetFeature]) -> Self {
let version = target.spirv_version();
let memory_model = target.memory_model();
@ -457,7 +489,12 @@ impl BuilderSpirv {
bug!("Function not found: {}", id);
}
pub fn def_constant(&self, ty: Word, val: SpirvConst) -> SpirvValue {
pub(crate) fn def_constant_cx(
&self,
ty: Word,
val: SpirvConst<'_>,
cx: &CodegenCx<'tcx>,
) -> SpirvValue {
let val_with_type = WithType { ty, val };
let mut builder = self.builder(BuilderCursor::default());
if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) {
@ -486,7 +523,7 @@ impl BuilderSpirv {
SpirvConst::Null => builder.constant_null(ty),
SpirvConst::Undef | SpirvConst::ZombieUndefForFnAddr => builder.undef(ty, None),
SpirvConst::Composite(ref v) => builder.constant_composite(ty, v.iter().copied()),
SpirvConst::Composite(v) => builder.constant_composite(ty, v.iter().copied()),
SpirvConst::PtrTo { pointee } => {
builder.variable(ty, None, StorageClass::Private, Some(pointee))
@ -517,7 +554,7 @@ impl BuilderSpirv {
Ok(())
}
SpirvConst::Composite(ref v) => v.iter().fold(Ok(()), |composite_legal, field| {
SpirvConst::Composite(v) => v.iter().fold(Ok(()), |composite_legal, field| {
let field_entry = &self.id_to_const.borrow()[field];
let field_legal_in_composite = field_entry.legal.and(
// `field` is itself some legal `SpirvConst`, but can we have
@ -556,14 +593,11 @@ impl BuilderSpirv {
}
},
};
let val = val.tcx_arena_alloc_slices(cx);
assert_matches!(
self.const_to_id.borrow_mut().insert(
WithType {
ty,
val: val.clone()
},
WithConstLegality { val: id, legal }
),
self.const_to_id
.borrow_mut()
.insert(WithType { ty, val }, WithConstLegality { val: id, legal }),
None
);
assert_matches!(
@ -581,10 +615,10 @@ impl BuilderSpirv {
SpirvValue { kind, ty }
}
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst> {
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst<'tcx>> {
match def.kind {
SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
Some(self.id_to_const.borrow().get(&id)?.val.clone())
Some(self.id_to_const.borrow().get(&id)?.val)
}
_ => None,
}

View File

@ -12,34 +12,38 @@ use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{self, AddressSpace, HasDataLayout, Integer, Primitive, Size};
impl<'tcx> CodegenCx<'tcx> {
pub fn def_constant(&self, ty: Word, val: SpirvConst<'_>) -> SpirvValue {
self.builder.def_constant_cx(ty, val, self)
}
pub fn constant_u8(&self, span: Span, val: u8) -> SpirvValue {
let ty = SpirvType::Integer(8, false).def(span, self);
self.builder.def_constant(ty, SpirvConst::U32(val as u32))
self.def_constant(ty, SpirvConst::U32(val as u32))
}
pub fn constant_i16(&self, span: Span, val: i16) -> SpirvValue {
let ty = SpirvType::Integer(16, true).def(span, self);
self.builder.def_constant(ty, SpirvConst::U32(val as u32))
self.def_constant(ty, SpirvConst::U32(val as u32))
}
pub fn constant_u16(&self, span: Span, val: u16) -> SpirvValue {
let ty = SpirvType::Integer(16, false).def(span, self);
self.builder.def_constant(ty, SpirvConst::U32(val as u32))
self.def_constant(ty, SpirvConst::U32(val as u32))
}
pub fn constant_i32(&self, span: Span, val: i32) -> SpirvValue {
let ty = SpirvType::Integer(32, true).def(span, self);
self.builder.def_constant(ty, SpirvConst::U32(val as u32))
self.def_constant(ty, SpirvConst::U32(val as u32))
}
pub fn constant_u32(&self, span: Span, val: u32) -> SpirvValue {
let ty = SpirvType::Integer(32, false).def(span, self);
self.builder.def_constant(ty, SpirvConst::U32(val))
self.def_constant(ty, SpirvConst::U32(val))
}
pub fn constant_u64(&self, span: Span, val: u64) -> SpirvValue {
let ty = SpirvType::Integer(64, false).def(span, self);
self.builder.def_constant(ty, SpirvConst::U64(val))
self.def_constant(ty, SpirvConst::U64(val))
}
pub fn constant_int(&self, ty: Word, val: u64) -> SpirvValue {
@ -47,7 +51,7 @@ impl<'tcx> CodegenCx<'tcx> {
SpirvType::Integer(bits @ 8..=32, signed) => {
let size = Size::from_bits(bits);
let val = val as u128;
self.builder.def_constant(
self.def_constant(
ty,
SpirvConst::U32(if signed {
size.sign_extend(val)
@ -56,9 +60,9 @@ impl<'tcx> CodegenCx<'tcx> {
} as u32),
)
}
SpirvType::Integer(64, _) => self.builder.def_constant(ty, SpirvConst::U64(val)),
SpirvType::Integer(64, _) => self.def_constant(ty, SpirvConst::U64(val)),
SpirvType::Bool => match val {
0 | 1 => self.builder.def_constant(ty, SpirvConst::Bool(val != 0)),
0 | 1 => self.def_constant(ty, SpirvConst::Bool(val != 0)),
_ => self
.tcx
.sess
@ -78,24 +82,18 @@ impl<'tcx> CodegenCx<'tcx> {
pub fn constant_f32(&self, span: Span, val: f32) -> SpirvValue {
let ty = SpirvType::Float(32).def(span, self);
self.builder
.def_constant(ty, SpirvConst::F32(val.to_bits()))
self.def_constant(ty, SpirvConst::F32(val.to_bits()))
}
pub fn constant_f64(&self, span: Span, val: f64) -> SpirvValue {
let ty = SpirvType::Float(64).def(span, self);
self.builder
.def_constant(ty, SpirvConst::F64(val.to_bits()))
self.def_constant(ty, SpirvConst::F64(val.to_bits()))
}
pub fn constant_float(&self, ty: Word, val: f64) -> SpirvValue {
match self.lookup_type(ty) {
SpirvType::Float(32) => self
.builder
.def_constant(ty, SpirvConst::F32((val as f32).to_bits())),
SpirvType::Float(64) => self
.builder
.def_constant(ty, SpirvConst::F64(val.to_bits())),
SpirvType::Float(32) => self.def_constant(ty, SpirvConst::F32((val as f32).to_bits())),
SpirvType::Float(64) => self.def_constant(ty, SpirvConst::F64(val.to_bits())),
other => self.tcx.sess.fatal(&format!(
"constant_float invalid on type {}",
other.debug(ty, self)
@ -105,20 +103,20 @@ impl<'tcx> CodegenCx<'tcx> {
pub fn constant_bool(&self, span: Span, val: bool) -> SpirvValue {
let ty = SpirvType::Bool.def(span, self);
self.builder.def_constant(ty, SpirvConst::Bool(val))
self.def_constant(ty, SpirvConst::Bool(val))
}
pub fn constant_composite(&self, ty: Word, fields: impl Iterator<Item = Word>) -> SpirvValue {
self.builder
.def_constant(ty, SpirvConst::Composite(fields.collect()))
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
self.def_constant(ty, SpirvConst::Composite(&fields.collect::<Vec<_>>()))
}
pub fn constant_null(&self, ty: Word) -> SpirvValue {
self.builder.def_constant(ty, SpirvConst::Null)
self.def_constant(ty, SpirvConst::Null)
}
pub fn undef(&self, ty: Word) -> SpirvValue {
self.builder.def_constant(ty, SpirvConst::Undef)
self.def_constant(ty, SpirvConst::Undef)
}
}
@ -172,7 +170,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
.spirv_type(DUMMY_SP, self);
// FIXME(eddyb) include the actual byte data.
(
self.builder.def_constant(
self.def_constant(
self.type_ptr_to(str_ty),
SpirvConst::PtrTo {
pointee: self.undef(str_ty).def_cx(self),
@ -183,14 +181,15 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
}
fn const_struct(&self, elts: &[Self::Value], _packed: bool) -> Self::Value {
// Presumably this will get bitcasted to the right type?
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
let field_types = elts.iter().map(|f| f.ty).collect::<Vec<_>>();
let (field_offsets, size, align) = crate::abi::auto_struct_layout(self, &field_types);
let struct_ty = SpirvType::Adt {
def_id: None,
size,
align,
field_types,
field_offsets,
field_types: &field_types,
field_offsets: &field_offsets,
field_names: None,
}
.def(DUMMY_SP, self);

View File

@ -259,7 +259,7 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
impl<'tcx> StaticMethods for CodegenCx<'tcx> {
fn static_addr_of(&self, cv: Self::Value, _align: Align, _kind: Option<&str>) -> Self::Value {
self.builder.def_constant(
self.def_constant(
self.type_ptr_to(cv.ty),
SpirvConst::PtrTo {
pointee: cv.def_cx(self),

View File

@ -125,7 +125,7 @@ impl<'tcx> CodegenCx<'tcx> {
let void = SpirvType::Void.def(span, self);
let fn_void_void = SpirvType::Function {
return_type: void,
arguments: vec![],
arguments: &[],
}
.def(span, self);
let mut emit = self.emit_global();
@ -666,7 +666,7 @@ impl<'tcx> CodegenCx<'tcx> {
SpirvType::Bool => *has_bool = true,
SpirvType::Integer(_, _) | SpirvType::Float(64) => *must_be_flat = true,
SpirvType::Adt { field_types, .. } => {
for f in field_types {
for &f in field_types {
recurse(cx, f, has_bool, must_be_flat);
}
}
@ -683,7 +683,7 @@ impl<'tcx> CodegenCx<'tcx> {
arguments,
} => {
recurse(cx, return_type, has_bool, must_be_flat);
for a in arguments {
for &a in arguments {
recurse(cx, a, has_bool, must_be_flat);
}
}

View File

@ -40,7 +40,7 @@ pub struct CodegenCx<'tcx> {
pub tcx: TyCtxt<'tcx>,
pub codegen_unit: &'tcx CodegenUnit<'tcx>,
/// Spir-v module builder
pub builder: BuilderSpirv,
pub builder: BuilderSpirv<'tcx>,
/// Map from MIR function to spir-v function ID
pub instances: RefCell<FxHashMap<Instance<'tcx>, SpirvValue>>,
/// Map from function ID to parameter list
@ -144,14 +144,12 @@ impl<'tcx> CodegenCx<'tcx> {
self.builder.builder(cursor)
}
// FIXME(eddyb) this should not clone the `SpirvType` back out, but we'd have
// to fix ~100 callsites (to add a `*` deref), if we change this signature.
#[track_caller]
pub fn lookup_type(&self, ty: Word) -> SpirvType {
(*self.type_cache.lookup(ty)).clone()
pub fn lookup_type(&self, ty: Word) -> SpirvType<'tcx> {
self.type_cache.lookup(ty)
}
pub fn debug_type<'cx>(&'cx self, ty: Word) -> SpirvTypePrinter<'cx, 'tcx> {
pub fn debug_type(&self, ty: Word) -> SpirvTypePrinter<'_, 'tcx> {
self.lookup_type(ty).debug(ty, self)
}
@ -537,8 +535,7 @@ impl<'tcx> MiscMethods<'tcx> for CodegenCx<'tcx> {
if self.is_system_crate() {
// Create these undefs up front instead of on demand in SpirvValue::def because
// SpirvValue::def can't use cx.emit()
self.builder
.def_constant(ty, SpirvConst::ZombieUndefForFnAddr);
self.def_constant(ty, SpirvConst::ZombieUndefForFnAddr);
}
SpirvValue {

View File

@ -180,18 +180,19 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
fn type_func(&self, args: &[Self::Type], ret: Self::Type) -> Self::Type {
SpirvType::Function {
return_type: ret,
arguments: args.to_vec(),
arguments: args,
}
.def(DUMMY_SP, self)
}
fn type_struct(&self, els: &[Self::Type], _packed: bool) -> Self::Type {
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
let (field_offsets, size, align) = crate::abi::auto_struct_layout(self, els);
SpirvType::Adt {
def_id: None,
align,
size,
field_types: els.to_vec(),
field_offsets,
field_types: els,
field_offsets: &field_offsets,
field_names: None,
}
.def(DUMMY_SP, self)

View File

@ -7,12 +7,11 @@ use rspirv::spirv::{Capability, Decoration, Dim, ImageFormat, StorageClass, Word
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::span_bug;
use rustc_span::def_id::DefId;
use rustc_span::Span;
use rustc_span::{Span, Symbol};
use rustc_target::abi::{Align, Size};
use std::cell::RefCell;
use std::fmt;
use std::iter;
use std::rc::Rc;
use std::sync::{LazyLock, Mutex};
/// Spir-v types are represented as simple Words, which are the `result_id` of instructions like
@ -21,8 +20,10 @@ use std::sync::{LazyLock, Mutex};
/// information. All types that are emitted are registered in `CodegenCx`, so you can always look
/// up the definition of a `Word` via `cx.lookup_type`. Note that this type doesn't actually store
/// the `result_id` of the type declaration instruction, merely the contents.
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum SpirvType {
//
// FIXME(eddyb) should `SpirvType`s be behind `&'tcx` from `tcx.arena.dropless`?
#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum SpirvType<'tcx> {
Void,
Bool,
Integer(u32, bool),
@ -36,9 +37,9 @@ pub enum SpirvType {
align: Align,
size: Option<Size>,
field_types: Vec<Word>,
field_offsets: Vec<Size>,
field_names: Option<Vec<String>>,
field_types: &'tcx [Word],
field_offsets: &'tcx [Size],
field_names: Option<&'tcx [Symbol]>,
},
Vector {
element: Word,
@ -63,7 +64,7 @@ pub enum SpirvType {
},
Function {
return_type: Word,
arguments: Vec<Word>,
arguments: &'tcx [Word],
},
Image {
sampled_type: Word,
@ -89,7 +90,7 @@ pub enum SpirvType {
RayQueryKhr,
}
impl SpirvType {
impl SpirvType<'_> {
/// Note: `Builder::type_*` should be called *nowhere else* but here, to ensure
/// `CodegenCx::type_defs` stays up-to-date
pub fn def(self, def_span: Span, cx: &CodegenCx<'_>) -> Word {
@ -149,9 +150,9 @@ impl SpirvType {
def_id: _,
align: _,
size: _,
ref field_types,
ref field_offsets,
ref field_names,
field_types,
field_offsets,
field_names,
} => {
let mut emit = cx.emit_global();
let result = emit.type_struct_id(id, field_types.iter().cloned());
@ -168,7 +169,7 @@ impl SpirvType {
}
if let Some(field_names) = field_names {
for (index, field_name) in field_names.iter().enumerate() {
emit.member_name(result, index as u32, field_name);
emit.member_name(result, index as u32, field_name.as_str());
}
}
result
@ -215,7 +216,7 @@ impl SpirvType {
.emit_global()
.type_pointer(id, StorageClass::Generic, pointee);
// no pointers to functions
if let Self::Function { .. } = cx.lookup_type(pointee) {
if let SpirvType::Function { .. } = cx.lookup_type(pointee) {
cx.zombie_even_in_user_code(
result,
def_span,
@ -226,7 +227,7 @@ impl SpirvType {
}
Self::Function {
return_type,
ref arguments,
arguments,
} => cx
.emit_global()
.type_function_id(id, return_type, arguments.iter().cloned()),
@ -271,7 +272,7 @@ impl SpirvType {
result
}
};
cx.type_cache_def(result, self, def_span);
cx.type_cache_def(result, self.tcx_arena_alloc_slices(cx), def_span);
result
}
@ -291,7 +292,7 @@ impl SpirvType {
cx.emit_global()
.type_pointer(Some(id), StorageClass::Generic, pointee);
// no pointers to functions
if let Self::Function { .. } = cx.lookup_type(pointee) {
if let SpirvType::Function { .. } = cx.lookup_type(pointee) {
cx.zombie_even_in_user_code(
result,
def_span,
@ -305,7 +306,7 @@ impl SpirvType {
.sess
.fatal(&format!("def_with_id invalid for type {:?}", other)),
};
cx.type_cache_def(result, self, def_span);
cx.type_cache_def(result, self.tcx_arena_alloc_slices(cx), def_span);
result
}
@ -327,17 +328,7 @@ impl SpirvType {
id
}
/// Use this if you want a pretty type printing that recursively prints the types within (e.g. struct fields)
pub fn debug<'cx, 'tcx>(
// FIXME(eddyb) don't require a clone of the `SpirvType` here.
self,
id: Word,
cx: &'cx CodegenCx<'tcx>,
) -> SpirvTypePrinter<'cx, 'tcx> {
SpirvTypePrinter { ty: self, id, cx }
}
pub fn sizeof<'tcx>(&self, cx: &CodegenCx<'tcx>) -> Option<Size> {
pub fn sizeof(&self, cx: &CodegenCx<'_>) -> Option<Size> {
let result = match *self {
// Types that have a dynamic size, or no concept of size at all.
Self::Void | Self::RuntimeArray { .. } | Self::Function { .. } => return None,
@ -364,7 +355,7 @@ impl SpirvType {
Some(result)
}
pub fn alignof<'tcx>(&self, cx: &CodegenCx<'tcx>) -> Align {
pub fn alignof(&self, cx: &CodegenCx<'_>) -> Align {
match *self {
// Types that have no concept of size or alignment.
Self::Void | Self::Function { .. } => Align::from_bytes(0).unwrap(),
@ -392,13 +383,95 @@ impl SpirvType {
Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).alignof(cx),
}
}
/// Replace `&[T]` fields with `&'tcx [T]` ones produced by calling
/// `tcx.arena.dropless.alloc_slice(...)` - this is done late for two reasons:
/// 1. it avoids allocating in the arena when the cache would be hit anyway,
/// which would create "garbage" (as in, unreachable allocations)
/// (ideally these would also be interned, but that's even more refactors)
/// 2. an empty slice is disallowed (as it's usually handled as a special
/// case elsewhere, e.g. `rustc`'s `ty::List` - sadly we can't use that)
fn tcx_arena_alloc_slices<'tcx>(self, cx: &CodegenCx<'tcx>) -> SpirvType<'tcx> {
fn arena_alloc_slice<'tcx, T: Copy>(cx: &CodegenCx<'tcx>, xs: &[T]) -> &'tcx [T] {
if xs.is_empty() {
&[]
} else {
cx.tcx.arena.dropless.alloc_slice(xs)
}
}
match self {
// FIXME(eddyb) these are all noop cases, could they be automated?
SpirvType::Void => SpirvType::Void,
SpirvType::Bool => SpirvType::Bool,
SpirvType::Integer(width, signedness) => SpirvType::Integer(width, signedness),
SpirvType::Float(width) => SpirvType::Float(width),
SpirvType::Vector { element, count } => SpirvType::Vector { element, count },
SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count },
SpirvType::Array { element, count } => SpirvType::Array { element, count },
SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element },
SpirvType::Pointer { pointee } => SpirvType::Pointer { pointee },
SpirvType::Image {
sampled_type,
dim,
depth,
arrayed,
multisampled,
sampled,
image_format,
} => SpirvType::Image {
sampled_type,
dim,
depth,
arrayed,
multisampled,
sampled,
image_format,
},
SpirvType::Sampler => SpirvType::Sampler,
SpirvType::SampledImage { image_type } => SpirvType::SampledImage { image_type },
SpirvType::InterfaceBlock { inner_type } => SpirvType::InterfaceBlock { inner_type },
SpirvType::AccelerationStructureKhr => SpirvType::AccelerationStructureKhr,
SpirvType::RayQueryKhr => SpirvType::RayQueryKhr,
// Only these variants have any slices to arena-allocate.
SpirvType::Adt {
def_id,
align,
size,
field_types,
field_offsets,
field_names,
} => SpirvType::Adt {
def_id,
align,
size,
field_types: arena_alloc_slice(cx, field_types),
field_offsets: arena_alloc_slice(cx, field_offsets),
field_names: field_names.map(|field_names| arena_alloc_slice(cx, field_names)),
},
SpirvType::Function {
return_type,
arguments,
} => SpirvType::Function {
return_type,
arguments: arena_alloc_slice(cx, arguments),
},
}
}
}
pub struct SpirvTypePrinter<'cx, 'tcx> {
impl<'a> SpirvType<'a> {
/// Use this if you want a pretty type printing that recursively prints the types within (e.g. struct fields)
pub fn debug<'tcx>(self, id: Word, cx: &'a CodegenCx<'tcx>) -> SpirvTypePrinter<'a, 'tcx> {
SpirvTypePrinter { ty: self, id, cx }
}
}
pub struct SpirvTypePrinter<'a, 'tcx> {
id: Word,
// FIXME(eddyb) don't require a clone of the `SpirvType` here.
ty: SpirvType,
cx: &'cx CodegenCx<'tcx>,
ty: SpirvType<'a>,
cx: &'a CodegenCx<'tcx>,
}
/// Types can be recursive, e.g. a struct can contain a pointer to itself. So, we need to keep
@ -434,9 +507,9 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
def_id,
align,
size,
ref field_types,
ref field_offsets,
ref field_names,
field_types,
field_offsets,
field_names,
} => {
let fields = field_types
.iter()
@ -448,8 +521,8 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("align", &align)
.field("size", &size)
.field("field_types", &fields)
.field("field_offsets", field_offsets)
.field("field_names", field_names)
.field("field_offsets", &field_offsets)
.field("field_names", &field_names)
.finish()
}
SpirvType::Vector { element, count } => f
@ -489,7 +562,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.finish(),
SpirvType::Function {
return_type,
ref arguments,
arguments,
} => {
let args = arguments
.iter()
@ -582,7 +655,7 @@ impl SpirvTypePrinter<'_, '_> {
def_id: _,
align: _,
size: _,
ref field_types,
field_types,
field_offsets: _,
ref field_names,
} => {
@ -641,7 +714,7 @@ impl SpirvTypePrinter<'_, '_> {
}
SpirvType::Function {
return_type,
ref arguments,
arguments,
} => {
f.write_str("fn(")?;
for (index, &arg) in arguments.iter().enumerate() {
@ -692,9 +765,8 @@ impl SpirvTypePrinter<'_, '_> {
#[derive(Default)]
pub struct TypeCache<'tcx> {
// FIXME(eddyb) use `tcx.arena.dropless` to get `&'tcx _`, instead of `Rc`.
pub id_to_spirv_type: RefCell<FxHashMap<Word, Rc<SpirvType>>>,
pub spirv_type_to_id: RefCell<FxHashMap<Rc<SpirvType>, Word>>,
pub id_to_spirv_type: RefCell<FxHashMap<Word, SpirvType<'tcx>>>,
pub spirv_type_to_id: RefCell<FxHashMap<SpirvType<'tcx>, Word>>,
/// Recursive pointer breaking
pub recursive_pointee_cache: RecursivePointeeCache<'tcx>,
@ -704,54 +776,40 @@ pub struct TypeCache<'tcx> {
type_names: RefCell<FxHashMap<Word, IndexSet<TyLayoutNameKey<'tcx>>>>,
}
impl TypeCache<'_> {
fn get(&self, ty: &SpirvType) -> Option<Word> {
impl<'tcx> TypeCache<'tcx> {
fn get(&self, ty: &SpirvType<'_>) -> Option<Word> {
self.spirv_type_to_id.borrow().get(ty).copied()
}
#[track_caller]
pub fn lookup(&self, id: Word) -> Rc<SpirvType> {
self.id_to_spirv_type
pub fn lookup(&self, id: Word) -> SpirvType<'tcx> {
*self
.id_to_spirv_type
.borrow()
.get(&id)
.expect("tried to lookup ID that wasn't a type, or has no definition")
.clone()
}
}
impl CodegenCx<'_> {
fn type_cache_def(&self, id: Word, ty: SpirvType, def_span: Span) {
let ty = Rc::new(ty);
if let Some(old_ty) = self
.type_cache
.id_to_spirv_type
.borrow_mut()
.insert(id, ty.clone())
{
impl<'tcx> CodegenCx<'tcx> {
fn type_cache_def(&self, id: Word, ty: SpirvType<'tcx>, def_span: Span) {
if let Some(old_ty) = self.type_cache.id_to_spirv_type.borrow_mut().insert(id, ty) {
span_bug!(
def_span,
"SPIR-V type with ID %{id} is being redefined\n\
old type: {old_ty}\n\
new type: {ty}",
// FIXME(eddyb) don't clone `SpirvType`s here.
old_ty = (*old_ty).clone().debug(id, self),
ty = (*ty).clone().debug(id, self)
old_ty = old_ty.debug(id, self),
ty = ty.debug(id, self)
);
}
if let Some(old_id) = self
.type_cache
.spirv_type_to_id
.borrow_mut()
.insert(ty.clone(), id)
{
if let Some(old_id) = self.type_cache.spirv_type_to_id.borrow_mut().insert(ty, id) {
span_bug!(
def_span,
"SPIR-V type is changing IDs (%{old_id} -> %{id}):\n\
{ty}",
// FIXME(eddyb) don't clone a `SpirvType` here.
ty = (*ty).clone().debug(id, self)
ty = ty.debug(id, self)
);
}
}

View File

@ -16,6 +16,7 @@ pub struct Symbols {
// Used by `is_blocklisted_fn`.
pub fmt_decimal: Symbol,
pub discriminant: Symbol,
pub rust_gpu: Symbol,
pub spirv: Symbol,
pub spirv_std: Symbol,
@ -24,9 +25,11 @@ pub struct Symbols {
pub entry_point_name: Symbol,
pub spv_intel_shader_integer_functions2: Symbol,
pub spv_khr_vulkan_memory_model: Symbol,
descriptor_set: Symbol,
binding: Symbol,
input_attachment_index: Symbol,
attributes: FxHashMap<Symbol, SpirvAttribute>,
execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
@ -373,6 +376,7 @@ impl Symbols {
Self {
fmt_decimal: Symbol::intern("fmt_decimal"),
discriminant: Symbol::intern("discriminant"),
rust_gpu: Symbol::intern("rust_gpu"),
spirv: Symbol::intern("spirv"),
spirv_std: Symbol::intern("spirv_std"),
@ -383,9 +387,11 @@ impl Symbols {
"SPV_INTEL_shader_integer_functions2",
),
spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
descriptor_set: Symbol::intern("descriptor_set"),
binding: Symbol::intern("binding"),
input_attachment_index: Symbol::intern("input_attachment_index"),
attributes,
execution_modes,
libm_intrinsics,