mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-21 22:34:34 +00:00
builder_spirv: replace the constants BiHashMap
with two FxHashMap
s.
This commit is contained in:
parent
3a9e7781d5
commit
00fe5a975e
@ -17,7 +17,7 @@ use rustc_middle::ty::Ty;
|
||||
use rustc_span::Span;
|
||||
use rustc_target::abi::{Abi, Align, Scalar, Size};
|
||||
use std::convert::TryInto;
|
||||
use std::iter::empty;
|
||||
use std::iter::{self, empty};
|
||||
use std::ops::Range;
|
||||
|
||||
macro_rules! simple_op {
|
||||
@ -125,7 +125,7 @@ fn memset_dynamic_scalar(
|
||||
.composite_construct(
|
||||
composite_type,
|
||||
None,
|
||||
std::iter::repeat(fill_var).take(byte_width),
|
||||
iter::repeat(fill_var).take(byte_width),
|
||||
)
|
||||
.unwrap();
|
||||
let result_type = if is_float {
|
||||
@ -214,15 +214,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
|
||||
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
|
||||
self.constant_composite(
|
||||
ty.clone().def(self.span(), self),
|
||||
vec![elem_pat; count as usize],
|
||||
iter::repeat(elem_pat).take(count as usize),
|
||||
)
|
||||
.def(self)
|
||||
}
|
||||
SpirvType::Array { element, count } => {
|
||||
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
|
||||
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
|
||||
self.constant_composite(ty.clone().def(self.span(), self), vec![elem_pat; count])
|
||||
.def(self)
|
||||
self.constant_composite(
|
||||
ty.clone().def(self.span(), self),
|
||||
iter::repeat(elem_pat).take(count),
|
||||
)
|
||||
.def(self)
|
||||
}
|
||||
SpirvType::RuntimeArray { .. } => {
|
||||
self.fatal("memset on runtime arrays not implemented yet")
|
||||
@ -267,7 +270,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
|
||||
.composite_construct(
|
||||
ty.clone().def(self.span(), self),
|
||||
None,
|
||||
std::iter::repeat(elem_pat).take(count),
|
||||
iter::repeat(elem_pat).take(count),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
@ -277,7 +280,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
|
||||
.composite_construct(
|
||||
ty.clone().def(self.span(), self),
|
||||
None,
|
||||
std::iter::repeat(elem_pat).take(count as usize),
|
||||
iter::repeat(elem_pat).take(count as usize),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
@ -1791,13 +1794,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
}
|
||||
.def(self.span(), self);
|
||||
if self.builder.lookup_const(elt).is_some() {
|
||||
self.constant_composite(result_type, vec![elt.def(self); num_elts])
|
||||
self.constant_composite(result_type, iter::repeat(elt.def(self)).take(num_elts))
|
||||
} else {
|
||||
self.emit()
|
||||
.composite_construct(
|
||||
result_type,
|
||||
None,
|
||||
std::iter::repeat(elt.def(self)).take(num_elts),
|
||||
iter::repeat(elt.def(self)).take(num_elts),
|
||||
)
|
||||
.unwrap()
|
||||
.with_type(result_type)
|
||||
|
@ -1,13 +1,14 @@
|
||||
use crate::builder;
|
||||
use crate::codegen_cx::CodegenCx;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use bimap::BiHashMap;
|
||||
use rspirv::dr::{Block, Builder, Module, Operand};
|
||||
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word};
|
||||
use rspirv::{binary::Assemble, binary::Disassemble};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_middle::bug;
|
||||
use rustc_span::{Span, DUMMY_SP};
|
||||
use std::cell::{RefCell, RefMut};
|
||||
use std::rc::Rc;
|
||||
use std::{fs::File, io::Write, path::Path};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
|
||||
@ -178,7 +179,7 @@ pub enum SpirvConst {
|
||||
/// f64 isn't hash, so store bits
|
||||
F64(u64),
|
||||
Bool(bool),
|
||||
Composite(Vec<Word>),
|
||||
Composite(Rc<[Word]>),
|
||||
Null,
|
||||
Undef,
|
||||
}
|
||||
@ -220,7 +221,11 @@ pub struct BuilderCursor {
|
||||
|
||||
pub struct BuilderSpirv {
|
||||
builder: RefCell<Builder>,
|
||||
constants: RefCell<BiHashMap<WithType<SpirvConst>, Word>>,
|
||||
|
||||
// Bidirectional maps between `SpirvConst` and the ID of the defined global
|
||||
// (e.g. `OpConstant...`) instruction.
|
||||
const_to_id: RefCell<FxHashMap<WithType<SpirvConst>, Word>>,
|
||||
id_to_const: RefCell<FxHashMap<Word, SpirvConst>>,
|
||||
}
|
||||
|
||||
impl BuilderSpirv {
|
||||
@ -263,7 +268,8 @@ impl BuilderSpirv {
|
||||
}
|
||||
Self {
|
||||
builder: RefCell::new(builder),
|
||||
constants: Default::default(),
|
||||
const_to_id: Default::default(),
|
||||
id_to_const: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -338,7 +344,7 @@ impl BuilderSpirv {
|
||||
pub fn def_constant(&self, ty: Word, val: SpirvConst) -> SpirvValue {
|
||||
let val_with_type = WithType { ty, val };
|
||||
let mut builder = self.builder(BuilderCursor::default());
|
||||
if let Some(id) = self.constants.borrow_mut().get_by_left(&val_with_type) {
|
||||
if let Some(id) = self.const_to_id.borrow().get(&val_with_type) {
|
||||
return id.with_type(ty);
|
||||
}
|
||||
let id = match val_with_type.val {
|
||||
@ -357,16 +363,22 @@ impl BuilderSpirv {
|
||||
SpirvConst::Null => builder.constant_null(ty),
|
||||
SpirvConst::Undef => builder.undef(ty, None),
|
||||
};
|
||||
self.constants
|
||||
.borrow_mut()
|
||||
.insert_no_overwrite(val_with_type, id)
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
self.const_to_id
|
||||
.borrow_mut()
|
||||
.insert(val_with_type.clone(), id),
|
||||
None
|
||||
);
|
||||
assert_matches!(
|
||||
self.id_to_const.borrow_mut().insert(id, val_with_type.val),
|
||||
None
|
||||
);
|
||||
id.with_type(ty)
|
||||
}
|
||||
|
||||
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst> {
|
||||
match def.kind {
|
||||
SpirvValueKind::Def(id) => Some(self.constants.borrow().get_by_right(&id)?.val.clone()),
|
||||
SpirvValueKind::Def(id) => Some(self.id_to_const.borrow().get(&id)?.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
@ -105,8 +105,9 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
self.builder.def_constant(ty, SpirvConst::Bool(val))
|
||||
}
|
||||
|
||||
pub fn constant_composite(&self, ty: Word, val: Vec<Word>) -> SpirvValue {
|
||||
self.builder.def_constant(ty, SpirvConst::Composite(val))
|
||||
pub fn constant_composite(&self, ty: Word, fields: impl Iterator<Item = Word>) -> SpirvValue {
|
||||
self.builder
|
||||
.def_constant(ty, SpirvConst::Composite(fields.collect()))
|
||||
}
|
||||
|
||||
pub fn constant_null(&self, ty: Word) -> SpirvValue {
|
||||
@ -182,7 +183,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
|
||||
field_names: None,
|
||||
}
|
||||
.def(DUMMY_SP, self);
|
||||
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect())
|
||||
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
|
||||
}
|
||||
|
||||
fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> {
|
||||
@ -448,7 +449,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
"create_const_alloc must consume all bytes of an Allocation after an unsized struct"
|
||||
);
|
||||
}
|
||||
self.constant_composite(ty, values)
|
||||
self.constant_composite(ty, values.into_iter())
|
||||
}
|
||||
SpirvType::Opaque { name } => self.tcx.sess.fatal(&format!(
|
||||
"Cannot create const alloc of type opaque: {}",
|
||||
@ -456,12 +457,10 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
)),
|
||||
SpirvType::Array { element, count } => {
|
||||
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
|
||||
let values = (0..count)
|
||||
.map(|_| {
|
||||
self.create_const_alloc2(alloc, offset, element)
|
||||
.def_cx(self)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let values = (0..count).map(|_| {
|
||||
self.create_const_alloc2(alloc, offset, element)
|
||||
.def_cx(self)
|
||||
});
|
||||
self.constant_composite(ty, values)
|
||||
}
|
||||
SpirvType::Vector { element, count } => {
|
||||
@ -469,16 +468,15 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
.sizeof(self)
|
||||
.expect("create_const_alloc: Vectors must be sized");
|
||||
let final_offset = *offset + total_size;
|
||||
let values = (0..count)
|
||||
.map(|_| {
|
||||
self.create_const_alloc2(alloc, offset, element)
|
||||
.def_cx(self)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let values = (0..count).map(|_| {
|
||||
self.create_const_alloc2(alloc, offset, element)
|
||||
.def_cx(self)
|
||||
});
|
||||
let result = self.constant_composite(ty, values);
|
||||
assert!(*offset <= final_offset);
|
||||
// Vectors sometimes have padding at the end (e.g. vec3), skip over it.
|
||||
*offset = final_offset;
|
||||
self.constant_composite(ty, values)
|
||||
result
|
||||
}
|
||||
SpirvType::RuntimeArray { element } => {
|
||||
let mut values = Vec::new();
|
||||
@ -488,7 +486,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
.def_cx(self),
|
||||
);
|
||||
}
|
||||
let result = self.constant_composite(ty, values);
|
||||
let result = self.constant_composite(ty, values.into_iter());
|
||||
// TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv:
|
||||
/*
|
||||
__constant struct A {
|
||||
|
Loading…
Reference in New Issue
Block a user