mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-21 22:34:34 +00:00
Fuse OpPtrAccessChain with previous OpAccessChain (#835)
* Fuse OpPtrAccessChain with previous OpAccessChain * constant-fold add, and change mul to wrap
This commit is contained in:
parent
3067848d1c
commit
480cd048e0
@ -641,7 +641,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
self.emit().unreachable().unwrap();
|
||||
}
|
||||
|
||||
simple_op! {add, i_add}
|
||||
simple_op! {
|
||||
add, i_add,
|
||||
fold_const {
|
||||
int(a, b) => a.wrapping_add(b)
|
||||
}
|
||||
}
|
||||
simple_op! {fadd, f_add}
|
||||
simple_op! {fadd_fast, f_add} // fast=normal
|
||||
simple_op! {sub, i_sub}
|
||||
@ -652,7 +657,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
// HACK(eddyb) `rustc_codegen_ssa` relies on `Builder` methods doing
|
||||
// on-the-fly constant-folding, for e.g. intrinsics that copy memory.
|
||||
fold_const {
|
||||
int(a, b) => a * b
|
||||
int(a, b) => a.wrapping_mul(b)
|
||||
}
|
||||
}
|
||||
simple_op! {fmul, f_mul}
|
||||
|
@ -13,7 +13,7 @@ use crate::abi::ConvSpirvType;
|
||||
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
|
||||
use crate::codegen_cx::CodegenCx;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use rspirv::spirv::Word;
|
||||
use rspirv::spirv::{self, Word};
|
||||
use rustc_codegen_ssa::mir::operand::OperandValue;
|
||||
use rustc_codegen_ssa::mir::place::PlaceRef;
|
||||
use rustc_codegen_ssa::traits::{
|
||||
@ -98,8 +98,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
|
||||
self.current_span.unwrap_or(DUMMY_SP)
|
||||
}
|
||||
|
||||
// Given an ID, check if it's defined by an OpAccessChain, and if it is, return its ptr/indices
|
||||
fn find_access_chain(&self, id: spirv::Word) -> Option<(spirv::Word, Vec<spirv::Word>)> {
|
||||
let emit = self.emit();
|
||||
let module = emit.module_ref();
|
||||
let func = &module.functions[emit.selected_function().unwrap()];
|
||||
let ptr_def_inst = func.all_inst_iter().find(|inst| inst.result_id == Some(id));
|
||||
if let Some(ptr_def_inst) = ptr_def_inst {
|
||||
if ptr_def_inst.class.opcode == spirv::Op::AccessChain
|
||||
|| ptr_def_inst.class.opcode == spirv::Op::InBoundsAccessChain
|
||||
{
|
||||
let ptr = ptr_def_inst.operands[0].unwrap_id_ref();
|
||||
let indices = ptr_def_inst.operands[1..]
|
||||
.iter()
|
||||
.map(|op| op.unwrap_id_ref())
|
||||
.collect::<Vec<spirv::Word>>();
|
||||
return Some((ptr, indices));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn gep_help(
|
||||
&self,
|
||||
&mut self,
|
||||
ty: Word,
|
||||
ptr: SpirvValue,
|
||||
indices: &[SpirvValue],
|
||||
@ -134,42 +155,58 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
|
||||
pointee: result_pointee_type,
|
||||
}
|
||||
.def(self.span(), self);
|
||||
if self.builder.lookup_const_u64(indices[0]) == Some(0) {
|
||||
|
||||
let ptr_id = ptr.def(self);
|
||||
if let Some((original_ptr, mut original_indices)) = self.find_access_chain(ptr_id) {
|
||||
// Transform the following:
|
||||
// OpAccessChain original_ptr [a, b, c]
|
||||
// OpPtrAccessChain ptr base [d, e, f]
|
||||
// into
|
||||
// OpAccessChain original_ptr [a, b, c + base, d, e, f]
|
||||
// to remove the need for OpPtrAccessChain
|
||||
let last = original_indices.last_mut().unwrap();
|
||||
*last = self
|
||||
.add(last.with_type(indices[0].ty), indices[0])
|
||||
.def(self);
|
||||
original_indices.append(&mut result_indices);
|
||||
let zero = self.constant_int(indices[0].ty, 0);
|
||||
self.emit_access_chain(
|
||||
result_type,
|
||||
original_ptr,
|
||||
zero,
|
||||
original_indices,
|
||||
is_inbounds,
|
||||
)
|
||||
} else {
|
||||
self.emit_access_chain(result_type, ptr_id, indices[0], result_indices, is_inbounds)
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_access_chain(
|
||||
&self,
|
||||
result_type: spirv::Word,
|
||||
pointer: spirv::Word,
|
||||
base: SpirvValue,
|
||||
indices: Vec<spirv::Word>,
|
||||
is_inbounds: bool,
|
||||
) -> SpirvValue {
|
||||
let mut emit = self.emit();
|
||||
if self.builder.lookup_const_u64(base) == Some(0) {
|
||||
if is_inbounds {
|
||||
self.emit()
|
||||
.in_bounds_access_chain(result_type, None, ptr.def(self), result_indices)
|
||||
.unwrap()
|
||||
.with_type(result_type)
|
||||
emit.in_bounds_access_chain(result_type, None, pointer, indices)
|
||||
} else {
|
||||
self.emit()
|
||||
.access_chain(result_type, None, ptr.def(self), result_indices)
|
||||
.unwrap()
|
||||
.with_type(result_type)
|
||||
emit.access_chain(result_type, None, pointer, indices)
|
||||
}
|
||||
.unwrap()
|
||||
.with_type(result_type)
|
||||
} else {
|
||||
let result = if is_inbounds {
|
||||
self.emit()
|
||||
.in_bounds_ptr_access_chain(
|
||||
result_type,
|
||||
None,
|
||||
ptr.def(self),
|
||||
indices[0].def(self),
|
||||
result_indices,
|
||||
)
|
||||
.unwrap()
|
||||
.with_type(result_type)
|
||||
emit.in_bounds_ptr_access_chain(result_type, None, pointer, base.def(self), indices)
|
||||
} else {
|
||||
self.emit()
|
||||
.ptr_access_chain(
|
||||
result_type,
|
||||
None,
|
||||
ptr.def(self),
|
||||
indices[0].def(self),
|
||||
result_indices,
|
||||
)
|
||||
.unwrap()
|
||||
.with_type(result_type)
|
||||
};
|
||||
emit.ptr_access_chain(result_type, None, pointer, base.def(self), indices)
|
||||
}
|
||||
.unwrap()
|
||||
.with_type(result_type);
|
||||
self.zombie(
|
||||
result.def(self),
|
||||
"Cannot offset a pointer to an arbitrary element",
|
||||
|
8
tests/ui/lang/consts/issue-834.rs
Normal file
8
tests/ui/lang/consts/issue-834.rs
Normal file
@ -0,0 +1,8 @@
|
||||
// build-pass
|
||||
|
||||
use spirv_std as _;
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main() {
|
||||
let arr = [0u32; 32];
|
||||
}
|
Loading…
Reference in New Issue
Block a user