Implement type caching

This is prep for implementing field offsets
This commit is contained in:
khyperia 2020-08-25 14:31:30 +02:00
parent 5ae80d3d57
commit 3a9eea264f
5 changed files with 73 additions and 21 deletions

View File

@ -3,7 +3,6 @@ use rspirv::spirv::{StorageClass, Word};
use rustc_middle::ty::{layout::TyAndLayout, TyKind};
use rustc_target::abi::{Abi, FieldsShape, LayoutOf, Primitive, Scalar, Size};
use std::fmt;
use std::iter::empty;
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum SpirvType {
@ -11,9 +10,6 @@ pub enum SpirvType {
Bool,
Integer(u32, bool),
Float(u32),
// TODO: Do we fold this into Adt?
/// Zero Sized Type
ZST,
/// This uses the rustc definition of "adt", i.e. a struct, enum, or union
Adt {
// TODO: enums/unions
@ -40,6 +36,10 @@ pub enum SpirvType {
impl SpirvType {
/// Note: Builder::type_* should be called *nowhere else* but here, to ensure CodegenCx::type_defs stays up-to-date
pub fn def<'spv, 'tcx>(&self, cx: &CodegenCx<'spv, 'tcx>) -> Word {
if let Some(&cached) = cx.type_cache.borrow().get(self) {
return cached;
}
//let cached = cx.type_cache.borrow_mut().entry(self);
// TODO: rspirv does a linear search to dedupe, probably want to cache here.
let result = match *self {
SpirvType::Void => cx.emit_global().type_void(),
@ -48,7 +48,6 @@ impl SpirvType {
.emit_global()
.type_int(width, if signedness { 1 } else { 0 }),
SpirvType::Float(width) => cx.emit_global().type_float(width),
SpirvType::ZST => cx.emit_global().type_struct(empty()),
SpirvType::Adt { ref field_types } => {
cx.emit_global().type_struct(field_types.iter().cloned())
}
@ -65,10 +64,23 @@ impl SpirvType {
.emit_global()
.type_function(return_type, arguments.iter().cloned()),
};
cx.type_defs
.borrow_mut()
.entry(result)
.or_insert_with(|| self.clone());
// Change to expect_none if/when stabilized
assert!(
cx.type_defs
.borrow_mut()
.insert(result, self.clone())
.is_none(),
"type_defs already had entry, caching failed? {:#?}",
self.clone().debug(cx)
);
assert!(
cx.type_cache
.borrow_mut()
.insert(self.clone(), result)
.is_none(),
"type_cache already had entry, caching failed? {:#?}",
self.clone().debug(cx)
);
result
}
@ -96,7 +108,6 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
.field("signedness", &signedness)
.finish(),
SpirvType::Float(width) => f.debug_struct("Float").field("width", &width).finish(),
SpirvType::ZST => f.debug_struct("ZST").finish(),
SpirvType::Adt { ref field_types } => {
let fields = field_types
.iter()
@ -155,7 +166,11 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word {
if ty.is_zst() {
return SpirvType::ZST.def(cx);
// An empty struct is zero-sized
return SpirvType::Adt {
field_types: Vec::new(),
}
.def(cx);
}
// Note: ty.abi is orthogonal to ty.variants and ty.fields, e.g. `ManuallyDrop<Result<isize, isize>>`

View File

@ -16,10 +16,22 @@ use rustc_target::abi::{Abi, Align, Size};
use std::iter::empty;
use std::ops::Range;
macro_rules! assert_ty_eq {
($codegen_cx:expr, $left:expr, $right:expr) => {
assert_eq!(
$left,
$right,
"Expected types to be equal:\n{:#?}\n==\n{:#?}",
$codegen_cx.debug_type($left),
$codegen_cx.debug_type($right)
)
};
}
macro_rules! simple_op {
($func_name:ident, $inst_name:ident) => {
fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
assert_eq!(lhs.ty, rhs.ty);
assert_ty_eq!(self, lhs.ty, rhs.ty);
let result_type = lhs.ty;
self.emit()
.$inst_name(result_type, None, lhs.def, rhs.def)
@ -313,7 +325,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
} => pointee,
ty => panic!("store called on variable that wasn't a pointer: {:?}", ty),
};
assert_eq!(ptr_elem_ty, val.ty);
assert_ty_eq!(self, ptr_elem_ty, val.ty);
self.emit().store(ptr.def, val.def, None, empty()).unwrap();
val
}
@ -507,7 +519,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
// TODO: Do we want to assert signedness matches the opcode? Is it possible to have one that doesn't match? Does
// spir-v allow nonmatching instructions?
use IntPredicate::*;
assert_eq!(lhs.ty, rhs.ty);
assert_ty_eq!(self, lhs.ty, rhs.ty);
let b = SpirvType::Bool.def(self);
let mut e = self.emit();
match op {
@ -528,7 +540,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
fn fcmp(&mut self, op: RealPredicate, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
use RealPredicate::*;
assert_eq!(lhs.ty, rhs.ty);
assert_ty_eq!(self, lhs.ty, rhs.ty);
let b = SpirvType::Bool.def(self);
let mut e = self.emit();
match op {
@ -594,7 +606,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
then_val: Self::Value,
else_val: Self::Value,
) -> Self::Value {
assert_eq!(then_val.ty, else_val.ty);
assert_ty_eq!(self, then_val.ty, else_val.ty);
let result_type = then_val.ty;
self.emit()
.select(result_type, None, cond.def, then_val.def, else_val.def)
@ -659,7 +671,9 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value {
match self.lookup_type(agg_val.ty) {
SpirvType::Adt { field_types } => assert_eq!(field_types[idx as usize], elt.ty),
SpirvType::Adt { field_types } => {
assert_ty_eq!(self, field_types[idx as usize], elt.ty)
}
other => panic!("insert_value not implemented on type {:?}", other),
};
self.emit()

View File

@ -202,11 +202,25 @@ impl<'a, 'spv, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'spv, 'tcx> {
fn store_arg(
&mut self,
_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
_val: Self::Value,
_dst: PlaceRef<'tcx, Self::Value>,
arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
val: Self::Value,
dst: PlaceRef<'tcx, Self::Value>,
) {
todo!()
if arg_abi.is_ignore() {
return;
}
if arg_abi.is_sized_indirect() {
OperandValue::Ref(val, None, arg_abi.layout.align.abi).store(self, dst)
} else if arg_abi.is_unsized_indirect() {
panic!("unsized `ArgAbi` must be handled through `store_fn_arg`");
} else if let PassMode::Cast(cast) = arg_abi.mode {
panic!(
"TODO: PassMode::Cast not implemented yet for store_arg: {:?}",
cast
);
} else {
OperandValue::Immediate(val).store(self, dst);
}
}
fn arg_memory_ty(&self, arg_abi: &ArgAbi<'tcx, Ty<'tcx>>) -> Self::Type {

View File

@ -133,6 +133,7 @@ impl BuilderSpirv {
fn def_constant(&self, ty: Word, val: Operand) -> Word {
let mut builder = self.builder.borrow_mut();
// TODO: Cache these instead of doing a search.
for inst in &builder.module_ref().types_global_values {
if inst.class.opcode == Op::Constant
&& inst.result_type == Some(ty)

View File

@ -33,11 +33,18 @@ use std::collections::HashMap;
pub struct CodegenCx<'spv, 'tcx> {
pub tcx: TyCtxt<'tcx>,
pub codegen_unit: &'tcx CodegenUnit<'tcx>,
/// Not actually used much, builder is probably what you want (see comment on ModuleSpirv type)
pub spirv_module: &'spv ModuleSpirv,
/// Spir-v module builder
pub builder: BuilderSpirv,
/// Map from MIR function to spir-v function ID
pub function_defs: RefCell<HashMap<Instance<'tcx>, SpirvValue>>,
/// Map from function ID to parameter list
pub function_parameter_values: RefCell<HashMap<Word, Vec<SpirvValue>>>,
/// Map from ID to structure
pub type_defs: RefCell<HashMap<Word, SpirvType>>,
/// Inverse of type_defs (used to cache generating types)
pub type_cache: RefCell<HashMap<SpirvType, Word>>,
}
impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
@ -54,6 +61,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
function_defs: RefCell::new(HashMap::new()),
function_parameter_values: RefCell::new(HashMap::new()),
type_defs: RefCell::new(HashMap::new()),
type_cache: RefCell::new(HashMap::new()),
}
}