Fuse OpPtrAccessChain with previous OpAccessChain (#835)

* Fuse OpPtrAccessChain with previous OpAccessChain

* constant-fold add, and change mul to wrap
This commit is contained in:
Ashley Hauck 2021-12-28 11:39:37 +01:00 committed by GitHub
parent 3067848d1c
commit 480cd048e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 34 deletions

View File

@ -641,7 +641,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
self.emit().unreachable().unwrap();
}
simple_op! {add, i_add}
simple_op! {
add, i_add,
fold_const {
int(a, b) => a.wrapping_add(b)
}
}
simple_op! {fadd, f_add}
simple_op! {fadd_fast, f_add} // fast=normal
simple_op! {sub, i_sub}
@ -652,7 +657,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// HACK(eddyb) `rustc_codegen_ssa` relies on `Builder` methods doing
// on-the-fly constant-folding, for e.g. intrinsics that copy memory.
fold_const {
int(a, b) => a * b
int(a, b) => a.wrapping_mul(b)
}
}
simple_op! {fmul, f_mul}

View File

@ -13,7 +13,7 @@ use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use rspirv::spirv::{self, Word};
use rustc_codegen_ssa::mir::operand::OperandValue;
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
@ -98,8 +98,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.current_span.unwrap_or(DUMMY_SP)
}
// Given an ID, check if it's defined by an OpAccessChain, and if it is, return its ptr/indices
fn find_access_chain(&self, id: spirv::Word) -> Option<(spirv::Word, Vec<spirv::Word>)> {
let emit = self.emit();
let module = emit.module_ref();
let func = &module.functions[emit.selected_function().unwrap()];
let ptr_def_inst = func.all_inst_iter().find(|inst| inst.result_id == Some(id));
if let Some(ptr_def_inst) = ptr_def_inst {
if ptr_def_inst.class.opcode == spirv::Op::AccessChain
|| ptr_def_inst.class.opcode == spirv::Op::InBoundsAccessChain
{
let ptr = ptr_def_inst.operands[0].unwrap_id_ref();
let indices = ptr_def_inst.operands[1..]
.iter()
.map(|op| op.unwrap_id_ref())
.collect::<Vec<spirv::Word>>();
return Some((ptr, indices));
}
}
None
}
pub fn gep_help(
&self,
&mut self,
ty: Word,
ptr: SpirvValue,
indices: &[SpirvValue],
@ -134,42 +155,58 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
pointee: result_pointee_type,
}
.def(self.span(), self);
if self.builder.lookup_const_u64(indices[0]) == Some(0) {
let ptr_id = ptr.def(self);
if let Some((original_ptr, mut original_indices)) = self.find_access_chain(ptr_id) {
// Transform the following:
// OpAccessChain original_ptr [a, b, c]
// OpPtrAccessChain ptr base [d, e, f]
// into
// OpAccessChain original_ptr [a, b, c + base, d, e, f]
// to remove the need for OpPtrAccessChain
let last = original_indices.last_mut().unwrap();
*last = self
.add(last.with_type(indices[0].ty), indices[0])
.def(self);
original_indices.append(&mut result_indices);
let zero = self.constant_int(indices[0].ty, 0);
self.emit_access_chain(
result_type,
original_ptr,
zero,
original_indices,
is_inbounds,
)
} else {
self.emit_access_chain(result_type, ptr_id, indices[0], result_indices, is_inbounds)
}
}
fn emit_access_chain(
&self,
result_type: spirv::Word,
pointer: spirv::Word,
base: SpirvValue,
indices: Vec<spirv::Word>,
is_inbounds: bool,
) -> SpirvValue {
let mut emit = self.emit();
if self.builder.lookup_const_u64(base) == Some(0) {
if is_inbounds {
self.emit()
.in_bounds_access_chain(result_type, None, ptr.def(self), result_indices)
.unwrap()
.with_type(result_type)
emit.in_bounds_access_chain(result_type, None, pointer, indices)
} else {
self.emit()
.access_chain(result_type, None, ptr.def(self), result_indices)
.unwrap()
.with_type(result_type)
emit.access_chain(result_type, None, pointer, indices)
}
.unwrap()
.with_type(result_type)
} else {
let result = if is_inbounds {
self.emit()
.in_bounds_ptr_access_chain(
result_type,
None,
ptr.def(self),
indices[0].def(self),
result_indices,
)
.unwrap()
.with_type(result_type)
emit.in_bounds_ptr_access_chain(result_type, None, pointer, base.def(self), indices)
} else {
self.emit()
.ptr_access_chain(
result_type,
None,
ptr.def(self),
indices[0].def(self),
result_indices,
)
.unwrap()
.with_type(result_type)
};
emit.ptr_access_chain(result_type, None, pointer, base.def(self), indices)
}
.unwrap()
.with_type(result_type);
self.zombie(
result.def(self),
"Cannot offset a pointer to an arbitrary element",

View File

@ -0,0 +1,8 @@
// build-pass
use spirv_std as _;
#[spirv(fragment)]
pub fn main() {
let arr = [0u32; 32];
}