builder_spirv: replace the constants BiHashMap with two FxHashMaps.

This commit is contained in:
Eduard-Mihai Burtescu 2021-04-08 14:51:15 +03:00 committed by Eduard-Mihai Burtescu
parent 3a9e7781d5
commit 00fe5a975e
3 changed files with 50 additions and 37 deletions

View File

@ -17,7 +17,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Span;
use rustc_target::abi::{Abi, Align, Scalar, Size};
use std::convert::TryInto;
use std::iter::empty;
use std::iter::{self, empty};
use std::ops::Range;
macro_rules! simple_op {
@ -125,7 +125,7 @@ fn memset_dynamic_scalar(
.composite_construct(
composite_type,
None,
std::iter::repeat(fill_var).take(byte_width),
iter::repeat(fill_var).take(byte_width),
)
.unwrap();
let result_type = if is_float {
@ -214,15 +214,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
self.constant_composite(
ty.clone().def(self.span(), self),
vec![elem_pat; count as usize],
iter::repeat(elem_pat).take(count as usize),
)
.def(self)
}
SpirvType::Array { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
self.constant_composite(ty.clone().def(self.span(), self), vec![elem_pat; count])
.def(self)
self.constant_composite(
ty.clone().def(self.span(), self),
iter::repeat(elem_pat).take(count),
)
.def(self)
}
SpirvType::RuntimeArray { .. } => {
self.fatal("memset on runtime arrays not implemented yet")
@ -267,7 +270,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.composite_construct(
ty.clone().def(self.span(), self),
None,
std::iter::repeat(elem_pat).take(count),
iter::repeat(elem_pat).take(count),
)
.unwrap()
}
@ -277,7 +280,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.composite_construct(
ty.clone().def(self.span(), self),
None,
std::iter::repeat(elem_pat).take(count as usize),
iter::repeat(elem_pat).take(count as usize),
)
.unwrap()
}
@ -1791,13 +1794,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
.def(self.span(), self);
if self.builder.lookup_const(elt).is_some() {
self.constant_composite(result_type, vec![elt.def(self); num_elts])
self.constant_composite(result_type, iter::repeat(elt.def(self)).take(num_elts))
} else {
self.emit()
.composite_construct(
result_type,
None,
std::iter::repeat(elt.def(self)).take(num_elts),
iter::repeat(elt.def(self)).take(num_elts),
)
.unwrap()
.with_type(result_type)

View File

@ -1,13 +1,14 @@
use crate::builder;
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use bimap::BiHashMap;
use rspirv::dr::{Block, Builder, Module, Operand};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word};
use rspirv::{binary::Assemble, binary::Disassemble};
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::bug;
use rustc_span::{Span, DUMMY_SP};
use std::cell::{RefCell, RefMut};
use std::rc::Rc;
use std::{fs::File, io::Write, path::Path};
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
@ -178,7 +179,7 @@ pub enum SpirvConst {
/// f64 isn't hash, so store bits
F64(u64),
Bool(bool),
Composite(Vec<Word>),
Composite(Rc<[Word]>),
Null,
Undef,
}
@ -220,7 +221,11 @@ pub struct BuilderCursor {
pub struct BuilderSpirv {
builder: RefCell<Builder>,
constants: RefCell<BiHashMap<WithType<SpirvConst>, Word>>,
// Bidirectional maps between `SpirvConst` and the ID of the defined global
// (e.g. `OpConstant...`) instruction.
const_to_id: RefCell<FxHashMap<WithType<SpirvConst>, Word>>,
id_to_const: RefCell<FxHashMap<Word, SpirvConst>>,
}
impl BuilderSpirv {
@ -263,7 +268,8 @@ impl BuilderSpirv {
}
Self {
builder: RefCell::new(builder),
constants: Default::default(),
const_to_id: Default::default(),
id_to_const: Default::default(),
}
}
@ -338,7 +344,7 @@ impl BuilderSpirv {
pub fn def_constant(&self, ty: Word, val: SpirvConst) -> SpirvValue {
let val_with_type = WithType { ty, val };
let mut builder = self.builder(BuilderCursor::default());
if let Some(id) = self.constants.borrow_mut().get_by_left(&val_with_type) {
if let Some(id) = self.const_to_id.borrow().get(&val_with_type) {
return id.with_type(ty);
}
let id = match val_with_type.val {
@ -357,16 +363,22 @@ impl BuilderSpirv {
SpirvConst::Null => builder.constant_null(ty),
SpirvConst::Undef => builder.undef(ty, None),
};
self.constants
.borrow_mut()
.insert_no_overwrite(val_with_type, id)
.unwrap();
assert_matches!(
self.const_to_id
.borrow_mut()
.insert(val_with_type.clone(), id),
None
);
assert_matches!(
self.id_to_const.borrow_mut().insert(id, val_with_type.val),
None
);
id.with_type(ty)
}
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst> {
match def.kind {
SpirvValueKind::Def(id) => Some(self.constants.borrow().get_by_right(&id)?.val.clone()),
SpirvValueKind::Def(id) => Some(self.id_to_const.borrow().get(&id)?.clone()),
_ => None,
}
}

View File

@ -105,8 +105,9 @@ impl<'tcx> CodegenCx<'tcx> {
self.builder.def_constant(ty, SpirvConst::Bool(val))
}
pub fn constant_composite(&self, ty: Word, val: Vec<Word>) -> SpirvValue {
self.builder.def_constant(ty, SpirvConst::Composite(val))
pub fn constant_composite(&self, ty: Word, fields: impl Iterator<Item = Word>) -> SpirvValue {
self.builder
.def_constant(ty, SpirvConst::Composite(fields.collect()))
}
pub fn constant_null(&self, ty: Word) -> SpirvValue {
@ -182,7 +183,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
field_names: None,
}
.def(DUMMY_SP, self);
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect())
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
}
fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> {
@ -448,7 +449,7 @@ impl<'tcx> CodegenCx<'tcx> {
"create_const_alloc must consume all bytes of an Allocation after an unsized struct"
);
}
self.constant_composite(ty, values)
self.constant_composite(ty, values.into_iter())
}
SpirvType::Opaque { name } => self.tcx.sess.fatal(&format!(
"Cannot create const alloc of type opaque: {}",
@ -456,12 +457,10 @@ impl<'tcx> CodegenCx<'tcx> {
)),
SpirvType::Array { element, count } => {
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
let values = (0..count)
.map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
})
.collect::<Vec<_>>();
let values = (0..count).map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
});
self.constant_composite(ty, values)
}
SpirvType::Vector { element, count } => {
@ -469,16 +468,15 @@ impl<'tcx> CodegenCx<'tcx> {
.sizeof(self)
.expect("create_const_alloc: Vectors must be sized");
let final_offset = *offset + total_size;
let values = (0..count)
.map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
})
.collect::<Vec<_>>();
let values = (0..count).map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
});
let result = self.constant_composite(ty, values);
assert!(*offset <= final_offset);
// Vectors sometimes have padding at the end (e.g. vec3), skip over it.
*offset = final_offset;
self.constant_composite(ty, values)
result
}
SpirvType::RuntimeArray { element } => {
let mut values = Vec::new();
@ -488,7 +486,7 @@ impl<'tcx> CodegenCx<'tcx> {
.def_cx(self),
);
}
let result = self.constant_composite(ty, values);
let result = self.constant_composite(ty, values.into_iter());
// TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv:
/*
__constant struct A {