From 00fe5a975e855eba29d75e1bb83b668c791faed8 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Thu, 8 Apr 2021 14:51:15 +0300 Subject: [PATCH] builder_spirv: replace the constants `BiHashMap` with two `FxHashMap`s. --- .../src/builder/builder_methods.rs | 21 +++++++----- .../rustc_codegen_spirv/src/builder_spirv.rs | 32 +++++++++++------ .../src/codegen_cx/constant.rs | 34 +++++++++---------- 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 540fac2567..b190b8316b 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -17,7 +17,7 @@ use rustc_middle::ty::Ty; use rustc_span::Span; use rustc_target::abi::{Abi, Align, Scalar, Size}; use std::convert::TryInto; -use std::iter::empty; +use std::iter::{self, empty}; use std::ops::Range; macro_rules! simple_op { @@ -125,7 +125,7 @@ fn memset_dynamic_scalar( .composite_construct( composite_type, None, - std::iter::repeat(fill_var).take(byte_width), + iter::repeat(fill_var).take(byte_width), ) .unwrap(); let result_type = if is_float { @@ -214,15 +214,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); self.constant_composite( ty.clone().def(self.span(), self), - vec![elem_pat; count as usize], + iter::repeat(elem_pat).take(count as usize), ) .def(self) } SpirvType::Array { element, count } => { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); let count = self.builder.lookup_const_u64(count).unwrap() as usize; - self.constant_composite(ty.clone().def(self.span(), self), vec![elem_pat; count]) - .def(self) + self.constant_composite( + ty.clone().def(self.span(), self), + iter::repeat(elem_pat).take(count), + ) + .def(self) } SpirvType::RuntimeArray { .. } => { self.fatal("memset on runtime arrays not implemented yet") @@ -267,7 +270,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .composite_construct( ty.clone().def(self.span(), self), None, - std::iter::repeat(elem_pat).take(count), + iter::repeat(elem_pat).take(count), ) .unwrap() } @@ -277,7 +280,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .composite_construct( ty.clone().def(self.span(), self), None, - std::iter::repeat(elem_pat).take(count as usize), + iter::repeat(elem_pat).take(count as usize), ) .unwrap() } @@ -1791,13 +1794,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } .def(self.span(), self); if self.builder.lookup_const(elt).is_some() { - self.constant_composite(result_type, vec![elt.def(self); num_elts]) + self.constant_composite(result_type, iter::repeat(elt.def(self)).take(num_elts)) } else { self.emit() .composite_construct( result_type, None, - std::iter::repeat(elt.def(self)).take(num_elts), + iter::repeat(elt.def(self)).take(num_elts), ) .unwrap() .with_type(result_type) diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index fd4d225d39..5249fdb8b1 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -1,13 +1,14 @@ use crate::builder; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use bimap::BiHashMap; use rspirv::dr::{Block, Builder, Module, Operand}; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word}; use rspirv::{binary::Assemble, binary::Disassemble}; +use rustc_data_structures::fx::FxHashMap; use rustc_middle::bug; use rustc_span::{Span, DUMMY_SP}; use std::cell::{RefCell, RefMut}; +use std::rc::Rc; use std::{fs::File, io::Write, path::Path}; #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] @@ -178,7 +179,7 @@ pub enum SpirvConst { /// f64 isn't hash, so store bits F64(u64), Bool(bool), - Composite(Vec), + Composite(Rc<[Word]>), Null, Undef, } @@ -220,7 +221,11 @@ pub struct BuilderCursor { pub struct BuilderSpirv { builder: RefCell, - constants: RefCell, Word>>, + + // Bidirectional maps between `SpirvConst` and the ID of the defined global + // (e.g. `OpConstant...`) instruction. + const_to_id: RefCell, Word>>, + id_to_const: RefCell>, } impl BuilderSpirv { @@ -263,7 +268,8 @@ impl BuilderSpirv { } Self { builder: RefCell::new(builder), - constants: Default::default(), + const_to_id: Default::default(), + id_to_const: Default::default(), } } @@ -338,7 +344,7 @@ impl BuilderSpirv { pub fn def_constant(&self, ty: Word, val: SpirvConst) -> SpirvValue { let val_with_type = WithType { ty, val }; let mut builder = self.builder(BuilderCursor::default()); - if let Some(id) = self.constants.borrow_mut().get_by_left(&val_with_type) { + if let Some(id) = self.const_to_id.borrow().get(&val_with_type) { return id.with_type(ty); } let id = match val_with_type.val { @@ -357,16 +363,22 @@ impl BuilderSpirv { SpirvConst::Null => builder.constant_null(ty), SpirvConst::Undef => builder.undef(ty, None), }; - self.constants - .borrow_mut() - .insert_no_overwrite(val_with_type, id) - .unwrap(); + assert_matches!( + self.const_to_id + .borrow_mut() + .insert(val_with_type.clone(), id), + None + ); + assert_matches!( + self.id_to_const.borrow_mut().insert(id, val_with_type.val), + None + ); id.with_type(ty) } pub fn lookup_const(&self, def: SpirvValue) -> Option { match def.kind { - SpirvValueKind::Def(id) => Some(self.constants.borrow().get_by_right(&id)?.val.clone()), + SpirvValueKind::Def(id) => Some(self.id_to_const.borrow().get(&id)?.clone()), _ => None, } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index b0a8d4a868..ec7d29e59d 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -105,8 +105,9 @@ impl<'tcx> CodegenCx<'tcx> { self.builder.def_constant(ty, SpirvConst::Bool(val)) } - pub fn constant_composite(&self, ty: Word, val: Vec) -> SpirvValue { - self.builder.def_constant(ty, SpirvConst::Composite(val)) + pub fn constant_composite(&self, ty: Word, fields: impl Iterator) -> SpirvValue { + self.builder + .def_constant(ty, SpirvConst::Composite(fields.collect())) } pub fn constant_null(&self, ty: Word) -> SpirvValue { @@ -182,7 +183,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { field_names: None, } .def(DUMMY_SP, self); - self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect()) + self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self))) } fn const_to_opt_uint(&self, v: Self::Value) -> Option { @@ -448,7 +449,7 @@ impl<'tcx> CodegenCx<'tcx> { "create_const_alloc must consume all bytes of an Allocation after an unsized struct" ); } - self.constant_composite(ty, values) + self.constant_composite(ty, values.into_iter()) } SpirvType::Opaque { name } => self.tcx.sess.fatal(&format!( "Cannot create const alloc of type opaque: {}", @@ -456,12 +457,10 @@ impl<'tcx> CodegenCx<'tcx> { )), SpirvType::Array { element, count } => { let count = self.builder.lookup_const_u64(count).unwrap() as usize; - let values = (0..count) - .map(|_| { - self.create_const_alloc2(alloc, offset, element) - .def_cx(self) - }) - .collect::>(); + let values = (0..count).map(|_| { + self.create_const_alloc2(alloc, offset, element) + .def_cx(self) + }); self.constant_composite(ty, values) } SpirvType::Vector { element, count } => { @@ -469,16 +468,15 @@ impl<'tcx> CodegenCx<'tcx> { .sizeof(self) .expect("create_const_alloc: Vectors must be sized"); let final_offset = *offset + total_size; - let values = (0..count) - .map(|_| { - self.create_const_alloc2(alloc, offset, element) - .def_cx(self) - }) - .collect::>(); + let values = (0..count).map(|_| { + self.create_const_alloc2(alloc, offset, element) + .def_cx(self) + }); + let result = self.constant_composite(ty, values); assert!(*offset <= final_offset); // Vectors sometimes have padding at the end (e.g. vec3), skip over it. *offset = final_offset; - self.constant_composite(ty, values) + result } SpirvType::RuntimeArray { element } => { let mut values = Vec::new(); @@ -488,7 +486,7 @@ impl<'tcx> CodegenCx<'tcx> { .def_cx(self), ); } - let result = self.constant_composite(ty, values); + let result = self.constant_composite(ty, values.into_iter()); // TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv: /* __constant struct A {