mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 14:56:27 +00:00
Implement type caching
This is prep for implementing field offsets
This commit is contained in:
parent
5ae80d3d57
commit
3a9eea264f
@ -3,7 +3,6 @@ use rspirv::spirv::{StorageClass, Word};
|
||||
use rustc_middle::ty::{layout::TyAndLayout, TyKind};
|
||||
use rustc_target::abi::{Abi, FieldsShape, LayoutOf, Primitive, Scalar, Size};
|
||||
use std::fmt;
|
||||
use std::iter::empty;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
|
||||
pub enum SpirvType {
|
||||
@ -11,9 +10,6 @@ pub enum SpirvType {
|
||||
Bool,
|
||||
Integer(u32, bool),
|
||||
Float(u32),
|
||||
// TODO: Do we fold this into Adt?
|
||||
/// Zero Sized Type
|
||||
ZST,
|
||||
/// This uses the rustc definition of "adt", i.e. a struct, enum, or union
|
||||
Adt {
|
||||
// TODO: enums/unions
|
||||
@ -40,6 +36,10 @@ pub enum SpirvType {
|
||||
impl SpirvType {
|
||||
/// Note: Builder::type_* should be called *nowhere else* but here, to ensure CodegenCx::type_defs stays up-to-date
|
||||
pub fn def<'spv, 'tcx>(&self, cx: &CodegenCx<'spv, 'tcx>) -> Word {
|
||||
if let Some(&cached) = cx.type_cache.borrow().get(self) {
|
||||
return cached;
|
||||
}
|
||||
//let cached = cx.type_cache.borrow_mut().entry(self);
|
||||
// TODO: rspirv does a linear search to dedupe, probably want to cache here.
|
||||
let result = match *self {
|
||||
SpirvType::Void => cx.emit_global().type_void(),
|
||||
@ -48,7 +48,6 @@ impl SpirvType {
|
||||
.emit_global()
|
||||
.type_int(width, if signedness { 1 } else { 0 }),
|
||||
SpirvType::Float(width) => cx.emit_global().type_float(width),
|
||||
SpirvType::ZST => cx.emit_global().type_struct(empty()),
|
||||
SpirvType::Adt { ref field_types } => {
|
||||
cx.emit_global().type_struct(field_types.iter().cloned())
|
||||
}
|
||||
@ -65,10 +64,23 @@ impl SpirvType {
|
||||
.emit_global()
|
||||
.type_function(return_type, arguments.iter().cloned()),
|
||||
};
|
||||
cx.type_defs
|
||||
.borrow_mut()
|
||||
.entry(result)
|
||||
.or_insert_with(|| self.clone());
|
||||
// Change to expect_none if/when stabilized
|
||||
assert!(
|
||||
cx.type_defs
|
||||
.borrow_mut()
|
||||
.insert(result, self.clone())
|
||||
.is_none(),
|
||||
"type_defs already had entry, caching failed? {:#?}",
|
||||
self.clone().debug(cx)
|
||||
);
|
||||
assert!(
|
||||
cx.type_cache
|
||||
.borrow_mut()
|
||||
.insert(self.clone(), result)
|
||||
.is_none(),
|
||||
"type_cache already had entry, caching failed? {:#?}",
|
||||
self.clone().debug(cx)
|
||||
);
|
||||
result
|
||||
}
|
||||
|
||||
@ -96,7 +108,6 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
.field("signedness", &signedness)
|
||||
.finish(),
|
||||
SpirvType::Float(width) => f.debug_struct("Float").field("width", &width).finish(),
|
||||
SpirvType::ZST => f.debug_struct("ZST").finish(),
|
||||
SpirvType::Adt { ref field_types } => {
|
||||
let fields = field_types
|
||||
.iter()
|
||||
@ -155,7 +166,11 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
|
||||
pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word {
|
||||
if ty.is_zst() {
|
||||
return SpirvType::ZST.def(cx);
|
||||
// An empty struct is zero-sized
|
||||
return SpirvType::Adt {
|
||||
field_types: Vec::new(),
|
||||
}
|
||||
.def(cx);
|
||||
}
|
||||
|
||||
// Note: ty.abi is orthogonal to ty.variants and ty.fields, e.g. `ManuallyDrop<Result<isize, isize>>`
|
||||
|
@ -16,10 +16,22 @@ use rustc_target::abi::{Abi, Align, Size};
|
||||
use std::iter::empty;
|
||||
use std::ops::Range;
|
||||
|
||||
macro_rules! assert_ty_eq {
|
||||
($codegen_cx:expr, $left:expr, $right:expr) => {
|
||||
assert_eq!(
|
||||
$left,
|
||||
$right,
|
||||
"Expected types to be equal:\n{:#?}\n==\n{:#?}",
|
||||
$codegen_cx.debug_type($left),
|
||||
$codegen_cx.debug_type($right)
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! simple_op {
|
||||
($func_name:ident, $inst_name:ident) => {
|
||||
fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
|
||||
assert_eq!(lhs.ty, rhs.ty);
|
||||
assert_ty_eq!(self, lhs.ty, rhs.ty);
|
||||
let result_type = lhs.ty;
|
||||
self.emit()
|
||||
.$inst_name(result_type, None, lhs.def, rhs.def)
|
||||
@ -313,7 +325,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
} => pointee,
|
||||
ty => panic!("store called on variable that wasn't a pointer: {:?}", ty),
|
||||
};
|
||||
assert_eq!(ptr_elem_ty, val.ty);
|
||||
assert_ty_eq!(self, ptr_elem_ty, val.ty);
|
||||
self.emit().store(ptr.def, val.def, None, empty()).unwrap();
|
||||
val
|
||||
}
|
||||
@ -507,7 +519,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
// TODO: Do we want to assert signedness matches the opcode? Is it possible to have one that doesn't match? Does
|
||||
// spir-v allow nonmatching instructions?
|
||||
use IntPredicate::*;
|
||||
assert_eq!(lhs.ty, rhs.ty);
|
||||
assert_ty_eq!(self, lhs.ty, rhs.ty);
|
||||
let b = SpirvType::Bool.def(self);
|
||||
let mut e = self.emit();
|
||||
match op {
|
||||
@ -528,7 +540,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
|
||||
fn fcmp(&mut self, op: RealPredicate, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
|
||||
use RealPredicate::*;
|
||||
assert_eq!(lhs.ty, rhs.ty);
|
||||
assert_ty_eq!(self, lhs.ty, rhs.ty);
|
||||
let b = SpirvType::Bool.def(self);
|
||||
let mut e = self.emit();
|
||||
match op {
|
||||
@ -594,7 +606,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
then_val: Self::Value,
|
||||
else_val: Self::Value,
|
||||
) -> Self::Value {
|
||||
assert_eq!(then_val.ty, else_val.ty);
|
||||
assert_ty_eq!(self, then_val.ty, else_val.ty);
|
||||
let result_type = then_val.ty;
|
||||
self.emit()
|
||||
.select(result_type, None, cond.def, then_val.def, else_val.def)
|
||||
@ -659,7 +671,9 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
|
||||
fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value {
|
||||
match self.lookup_type(agg_val.ty) {
|
||||
SpirvType::Adt { field_types } => assert_eq!(field_types[idx as usize], elt.ty),
|
||||
SpirvType::Adt { field_types } => {
|
||||
assert_ty_eq!(self, field_types[idx as usize], elt.ty)
|
||||
}
|
||||
other => panic!("insert_value not implemented on type {:?}", other),
|
||||
};
|
||||
self.emit()
|
||||
|
@ -202,11 +202,25 @@ impl<'a, 'spv, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
|
||||
fn store_arg(
|
||||
&mut self,
|
||||
_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
|
||||
_val: Self::Value,
|
||||
_dst: PlaceRef<'tcx, Self::Value>,
|
||||
arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
|
||||
val: Self::Value,
|
||||
dst: PlaceRef<'tcx, Self::Value>,
|
||||
) {
|
||||
todo!()
|
||||
if arg_abi.is_ignore() {
|
||||
return;
|
||||
}
|
||||
if arg_abi.is_sized_indirect() {
|
||||
OperandValue::Ref(val, None, arg_abi.layout.align.abi).store(self, dst)
|
||||
} else if arg_abi.is_unsized_indirect() {
|
||||
panic!("unsized `ArgAbi` must be handled through `store_fn_arg`");
|
||||
} else if let PassMode::Cast(cast) = arg_abi.mode {
|
||||
panic!(
|
||||
"TODO: PassMode::Cast not implemented yet for store_arg: {:?}",
|
||||
cast
|
||||
);
|
||||
} else {
|
||||
OperandValue::Immediate(val).store(self, dst);
|
||||
}
|
||||
}
|
||||
|
||||
fn arg_memory_ty(&self, arg_abi: &ArgAbi<'tcx, Ty<'tcx>>) -> Self::Type {
|
||||
|
@ -133,6 +133,7 @@ impl BuilderSpirv {
|
||||
|
||||
fn def_constant(&self, ty: Word, val: Operand) -> Word {
|
||||
let mut builder = self.builder.borrow_mut();
|
||||
// TODO: Cache these instead of doing a search.
|
||||
for inst in &builder.module_ref().types_global_values {
|
||||
if inst.class.opcode == Op::Constant
|
||||
&& inst.result_type == Some(ty)
|
||||
|
@ -33,11 +33,18 @@ use std::collections::HashMap;
|
||||
pub struct CodegenCx<'spv, 'tcx> {
|
||||
pub tcx: TyCtxt<'tcx>,
|
||||
pub codegen_unit: &'tcx CodegenUnit<'tcx>,
|
||||
/// Not actually used much, builder is probably what you want (see comment on ModuleSpirv type)
|
||||
pub spirv_module: &'spv ModuleSpirv,
|
||||
/// Spir-v module builder
|
||||
pub builder: BuilderSpirv,
|
||||
/// Map from MIR function to spir-v function ID
|
||||
pub function_defs: RefCell<HashMap<Instance<'tcx>, SpirvValue>>,
|
||||
/// Map from function ID to parameter list
|
||||
pub function_parameter_values: RefCell<HashMap<Word, Vec<SpirvValue>>>,
|
||||
/// Map from ID to structure
|
||||
pub type_defs: RefCell<HashMap<Word, SpirvType>>,
|
||||
/// Inverse of type_defs (used to cache generating types)
|
||||
pub type_cache: RefCell<HashMap<SpirvType, Word>>,
|
||||
}
|
||||
|
||||
impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
|
||||
@ -54,6 +61,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
|
||||
function_defs: RefCell::new(HashMap::new()),
|
||||
function_parameter_values: RefCell::new(HashMap::new()),
|
||||
type_defs: RefCell::new(HashMap::new()),
|
||||
type_cache: RefCell::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user