Implement auto-vectorization (#629)

* Implement auto-vectorization

* Run mem2reg a lot
This commit is contained in:
Ashley Hauck 2021-05-31 11:22:02 +02:00 committed by GitHub
parent fcf6ee76c8
commit 1046f7cdf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 480 additions and 13 deletions

View File

@ -27,14 +27,14 @@ pub fn mem2reg(
let preds = compute_preds(&func.blocks, &reachable);
let idom = compute_idom(&preds, &reachable);
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
insert_phis_all(
while insert_phis_all(
header,
types_global_values,
pointer_to_pointee,
constants,
&mut func.blocks,
dominance_frontier,
);
&dominance_frontier,
) {}
}
fn label_to_index(blocks: &[Block], id: Word) -> usize {
@ -146,15 +146,16 @@ fn compute_dominance_frontier(
dominance_frontier
}
// Returns true if variables were rewritten
fn insert_phis_all(
header: &mut ModuleHeader,
types_global_values: &mut Vec<Instruction>,
pointer_to_pointee: &FxHashMap<Word, Word>,
constants: &FxHashMap<Word, u32>,
blocks: &mut [Block],
dominance_frontier: Vec<FxHashSet<usize>>,
) {
let thing = blocks[0]
dominance_frontier: &[FxHashSet<usize>],
) -> bool {
let var_maps_and_types = blocks[0]
.instructions
.iter()
.filter(|inst| inst.class.opcode == Op::Variable)
@ -167,8 +168,11 @@ fn insert_phis_all(
))
})
.collect::<Vec<_>>();
for &(ref var_map, base_var_type) in &thing {
let blocks_with_phi = insert_phis(blocks, &dominance_frontier, var_map);
if var_maps_and_types.is_empty() {
return false;
}
for &(ref var_map, base_var_type) in &var_maps_and_types {
let blocks_with_phi = insert_phis(blocks, dominance_frontier, var_map);
let mut renamer = Renamer {
header,
types_global_values,
@ -185,7 +189,8 @@ fn insert_phis_all(
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
remove_nops(blocks);
}
remove_old_variables(blocks, &thing);
remove_old_variables(blocks, &var_maps_and_types);
true
}
#[derive(Debug)]
@ -488,11 +493,14 @@ fn remove_nops(blocks: &mut [Block]) {
}
}
fn remove_old_variables(blocks: &mut [Block], thing: &[(FxHashMap<u32, VarInfo>, u32)]) {
fn remove_old_variables(
blocks: &mut [Block],
var_maps_and_types: &[(FxHashMap<u32, VarInfo>, u32)],
) {
blocks[0].instructions.retain(|inst| {
inst.class.opcode != Op::Variable || {
let result_id = inst.result_id.unwrap();
thing
var_maps_and_types
.iter()
.all(|(var_map, _)| !var_map.contains_key(&result_id))
}
@ -502,7 +510,9 @@ fn remove_old_variables(blocks: &mut [Block], thing: &[(FxHashMap<u32, VarInfo>,
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
|| inst.operands.iter().all(|op| {
op.id_ref_any().map_or(true, |id| {
thing.iter().all(|(var_map, _)| !var_map.contains_key(&id))
var_maps_and_types
.iter()
.all(|(var_map, _)| !var_map.contains_key(&id))
})
})
})

View File

@ -7,6 +7,7 @@ mod import_export_link;
mod inline;
mod mem2reg;
mod new_structurizer;
mod peephole_opts;
mod simple_passes;
mod specializer;
mod structurizer;
@ -242,6 +243,15 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
}
}
{
let _timer = sess.timer("peephole_opts");
let types = peephole_opts::collect_types(&output);
for func in &mut output.functions {
peephole_opts::composite_construct(&types, func);
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
}
}
{
let _timer = sess.timer("link_remove_duplicate_lines");
duplicates::remove_duplicate_lines(&mut output);

View File

@ -0,0 +1,447 @@
use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::bug;
pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
module
.types_global_values
.iter()
.filter_map(|inst| Some((inst.result_id?, inst.clone())))
.collect()
}
/// Return the number of top-level components of the composite type `ty_id`.
///
/// Handles `OpTypeStruct` (member count), `OpTypeVector` (component count literal) and
/// `OpTypeArray` whose length is an `OpConstant` of integer type. Returns `None` for any other
/// type, for an unknown ID, or for an array whose length cannot be resolved that way.
fn composite_count(types: &FxHashMap<Word, Instruction>, ty_id: Word) -> Option<usize> {
    let ty = types.get(&ty_id)?;
    let opcode = ty.class.opcode;
    if opcode == Op::TypeStruct {
        // Every operand of the struct type is one member.
        return Some(ty.operands.len());
    }
    if opcode == Op::TypeVector {
        return Some(ty.operands[1].unwrap_literal_int32() as usize);
    }
    if opcode != Op::TypeArray {
        return None;
    }

    // The array length is given by ID; only a plain integer OpConstant is understood here.
    let length_def = types.get(&ty.operands[1].unwrap_id_ref())?;
    if length_def.class.opcode != Op::Constant {
        return None;
    }
    let length_ty = types.get(&length_def.result_type.unwrap())?;
    if length_ty.class.opcode != Op::TypeInt {
        return None;
    }
    match length_def.operands[0] {
        Operand::LiteralInt32(len) => Some(len as usize),
        Operand::LiteralInt64(len) => Some(len as usize),
        _ => bug!(),
    }
}
/// Given a chain of `OpCompositeInsert` instructions where all slots of the composite are
/// assigned, replace the last instruction of the chain with a single `OpCompositeConstruct`.
///
/// `types` maps result IDs of module-level instructions to their definitions (see
/// `collect_types`); it is used to find how many components the composite has. Only chains of
/// single-index inserts are handled. The earlier inserts of a rewritten chain are left in place
/// as dead code for other passes to clean up.
pub fn composite_construct(types: &FxHashMap<Word, Instruction>, function: &mut Function) {
    // Snapshot of every result-producing instruction in the function, so the chain can be
    // followed backwards while the blocks are being mutated.
    let defs = function
        .all_inst_iter()
        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
        .collect::<FxHashMap<Word, Instruction>>();
    for block in &mut function.blocks {
        for inst in &mut block.instructions {
            if inst.class.opcode != Op::CompositeInsert {
                continue;
            }
            // Get the number of components to expect
            let component_count = match composite_count(types, inst.result_type.unwrap()) {
                Some(c) => c,
                None => continue,
            };
            // Remember a map of index -> value for that index. If any index is missing (None)
            // afterwards, then we know not all slots have been filled in, and we should skip
            // optimizing this chain.
            let mut components = vec![None; component_count];
            let mut cur_inst: &Instruction = inst;
            // Start looping from the current instruction, through each instruction in the chain.
            while cur_inst.class.opcode == Op::CompositeInsert {
                if cur_inst.operands.len() != 3 {
                    // If there's more than one index, skip optimizing this chain.
                    break;
                }
                let value = cur_inst.operands[0].unwrap_id_ref();
                let index = cur_inst.operands[2].unwrap_literal_int32() as usize;
                if index >= components.len() {
                    // Theoretically shouldn't happen, as it's invalid SPIR-V if the index is out
                    // of bounds, but just stop optimizing instead of panicking here.
                    break;
                }
                // The chain is walked from the newest insert to the oldest, so the first value
                // seen for a slot is the one that survives in the final composite. An older
                // insert into the same slot was overwritten and must not replace it.
                if components[index].is_none() {
                    components[index] = Some(value);
                }
                // Follow back one in the chain of OpCompositeInsert
                cur_inst = match defs.get(&cur_inst.operands[1].unwrap_id_ref()) {
                    Some(i) => i,
                    None => break,
                };
            }
            // If all components are filled in (collect() returns Some), replace it with
            // `OpCompositeConstruct`
            if let Some(composite_construct_operands) = components
                .into_iter()
                .map(|v| v.map(Operand::IdRef))
                .collect::<Option<Vec<_>>>()
            {
                // Leave all the other instructions in the chain as dead code for other passes
                // to clean up.
                *inst = Instruction::new(
                    Op::CompositeConstruct,
                    inst.result_type,
                    inst.result_id,
                    composite_construct_operands,
                );
            }
        }
    }
}
/// One operand of a vectorized operation, as identified by `match_operands` from the
/// corresponding operands of the per-component scalar instructions.
#[derive(Debug)]
enum IdentifiedOperand {
    /// The operand to the vectorized operation is a straight-up vector. The payload is the ID of
    /// that vector, which can be used directly as the operand.
    Vector(Word),
    /// The operand to the vectorized operation is a collection of scalars that need to be packed
    /// together with `OpCompositeConstruct` before using the vectorized operation. The payload
    /// holds one scalar ID per component, in component order.
    Scalars(Vec<Word>),
    /// The operand to the vectorized operation is some non-value: for example, the `instruction`
    /// operand in `OpExtInst`. The payload is passed through to the vectorized operation as-is.
    NonValue(Operand),
}
/// Resolve `id` to the vector it was extracted from and the component index used.
///
/// Returns `Some((vector_id, index))` only when `id` is defined in `defs` by an
/// `OpCompositeExtract` with exactly one index, and the extracted-from value (looked up in
/// `defs`, then in `types`) has an `OpTypeVector` type with exactly `vector_width` components.
fn get_composite_and_index(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    id: Word,
    vector_width: u32,
) -> Option<(Word, u32)> {
    let extract = defs.get(&id)?;
    // Only a single-level extract is understood; a deeper index path bails out.
    if extract.class.opcode != Op::CompositeExtract || extract.operands.len() != 2 {
        return None;
    }
    let source = extract.operands[0].unwrap_id_ref();
    let field = extract.operands[1].unwrap_literal_int32();

    // The source may be a value defined in this function or a module-level value.
    let source_def = match defs.get(&source) {
        Some(def) => def,
        None => types.get(&source)?,
    };
    let source_ty = types.get(&source_def.result_type.unwrap())?;

    // The width must match exactly: with `vec2(a.x + b.x, a.y + b.y)` where `a` is a vec4,
    // rewriting to `a + b` would produce a vector of the wrong size.
    let is_expected_vector = source_ty.class.opcode == Op::TypeVector
        && source_ty.operands[1].unwrap_literal_int32() == vector_width;
    if is_expected_vector {
        Some((source, field))
    } else {
        None
    }
}
/// Look at operand `operand_index` of every instruction in `results` and check whether, taken
/// together, they are exactly the components `0, 1, 2, ...` of one and the same vector of width
/// `vector_width`, each obtained through `OpCompositeExtract`.
///
/// Returns the ID of that vector, or `None` if any operand is not an ID, is not such an
/// extract, comes from a different vector, or has an index that does not match its position.
fn match_vector_operand(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    results: &[&Instruction],
    operand_index: usize,
    vector_width: u32,
) -> Option<Word> {
    // Resolve one instruction's operand to (source vector, extracted index), if it has that shape.
    let source_of = |inst: &Instruction| match inst.operands[operand_index] {
        Operand::IdRef(id) => get_composite_and_index(types, defs, id, vector_width),
        _ => None,
    };

    // The first component fixes which vector all the others must come from.
    let (shared_vector, first_index) = source_of(results[0])?;
    if first_index != 0 {
        return None;
    }

    // Every remaining component must be extracted from that vector at its own position.
    let remaining_line_up = results.iter().enumerate().skip(1).all(|(lane, &inst)| {
        matches!(
            source_of(inst),
            Some((vector, index)) if vector == shared_vector && lane == index as usize
        )
    });
    if remaining_line_up {
        Some(shared_vector)
    } else {
        None
    }
}
/// Classify operand `operand_index` across all of `results`.
///
/// If the operands are the in-order components of a single vector (see `match_vector_operand`),
/// returns `IdentifiedOperand::Vector` with that vector. Otherwise returns
/// `IdentifiedOperand::Scalars` with the individual operand IDs, to be packed into a vector with
/// `OpCompositeConstruct` later. Returns `None` if the fallback hits an operand that is not an ID.
fn match_vector_or_scalars_operand(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    results: &[&Instruction],
    operand_index: usize,
    vector_width: u32,
) -> Option<IdentifiedOperand> {
    match_vector_operand(types, defs, results, operand_index, vector_width)
        .map(IdentifiedOperand::Vector)
        .or_else(|| {
            let mut scalars = Vec::with_capacity(results.len());
            for inst in results {
                match inst.operands[operand_index] {
                    Operand::IdRef(id) => scalars.push(id),
                    _ => return None,
                }
            }
            Some(IdentifiedOperand::Scalars(scalars))
        })
}
/// Return a copy of operand `operand_index` if it is identical in every instruction of
/// `results`, and `None` as soon as one instruction differs. This is used for operands that are
/// not values, for example the `instruction` operand of `OpExtInst`.
fn match_all_same_operand(results: &[&Instruction], operand_index: usize) -> Option<Operand> {
    let reference = &results[0].operands[operand_index];
    let diverges = results[1..]
        .iter()
        .any(|inst| inst.operands[operand_index] != *reference);
    if diverges {
        None
    } else {
        Some(reference.clone())
    }
}
/// Find the proper operands for the vectorized operation. This means finding the backing vector
/// for each scalar component, etc.
///
/// `results` holds the defining instruction of each scalar component, in component order.
/// Returns `None` when the instructions do not all share one opcode and operand count, when the
/// opcode is not one of the supported ones, or when no operand is backed by a vector (in which
/// case fusing would not be an improvement).
fn match_operands(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    results: &[&Instruction],
    vector_width: u32,
) -> Option<Vec<IdentifiedOperand>> {
    let first = results[0];
    let operation_opcode = first.class.opcode;
    // Check to make sure they're all the same opcode, and have the same number of arguments.
    if results
        .iter()
        .skip(1)
        .any(|r| r.class.opcode != operation_opcode || r.operands.len() != first.operands.len())
    {
        return None;
    }
    // TODO: There are probably other instructions relevant here.
    match operation_opcode {
        Op::IAdd
        | Op::FAdd
        | Op::ISub
        | Op::FSub
        | Op::IMul
        | Op::FMul
        | Op::UDiv
        | Op::SDiv
        | Op::FDiv
        | Op::UMod
        | Op::SRem
        | Op::FRem
        | Op::FMod
        | Op::ShiftRightLogical
        | Op::ShiftRightArithmetic
        | Op::ShiftLeftLogical
        | Op::BitwiseOr
        | Op::BitwiseXor
        | Op::BitwiseAnd => {
            let left = match_vector_or_scalars_operand(types, defs, results, 0, vector_width)?;
            let right = match_vector_or_scalars_operand(types, defs, results, 1, vector_width)?;
            match (left, right) {
                // Style choice: If all arguments are scalars, don't fuse this operation.
                (IdentifiedOperand::Scalars(_), IdentifiedOperand::Scalars(_)) => None,
                (left, right) => Some(vec![left, right]),
            }
        }
        Op::SNegate | Op::FNegate | Op::Not | Op::BitReverse => {
            // Unary operations are only fused when the single operand is a whole vector.
            let value = match_vector_operand(types, defs, results, 0, vector_width)?;
            Some(vec![IdentifiedOperand::Vector(value)])
        }
        Op::ExtInst => {
            // NOTE(review): this fuses any extended instruction as if it operated
            // component-wise. Some (e.g. GLSL.std.450 Length or Normalize) do not - confirm
            // that scalar calls of those cannot reach this point.
            let set = match_all_same_operand(results, 0)?;
            let instruction = match_all_same_operand(results, 1)?;
            // The first two operands (set and instruction number) are passed through as-is;
            // everything after them is a value parameter. The vector is sized up front so it is
            // allocated only once.
            let mut operands = Vec::with_capacity(first.operands.len());
            operands.push(IdentifiedOperand::NonValue(set));
            operands.push(IdentifiedOperand::NonValue(instruction));
            for i in 2..first.operands.len() {
                operands.push(match_vector_or_scalars_operand(
                    types,
                    defs,
                    results,
                    i,
                    vector_width,
                )?);
            }
            if operands
                .iter()
                .skip(2)
                .all(|p| matches!(p, IdentifiedOperand::Scalars(_)))
            {
                // Style choice: If all arguments are scalars, don't fuse this operation.
                return None;
            }
            Some(operands)
        }
        _ => None,
    }
}
/// Try to turn the `OpCompositeConstruct` at `instructions[*instruction_index]` into a single
/// vector operation.
///
/// This succeeds when the instruction builds a vector whose components are each produced by the
/// same kind of scalar operation and `match_operands` can find vector-shaped operands for it.
/// On success the replacement instruction is returned; the caller is expected to store it at
/// `instructions[*instruction_index]`. If an operand had to be packed from loose scalars, the
/// packing `OpCompositeConstruct` is inserted in front of the current instruction and
/// `*instruction_index` is advanced so that it still refers to the instruction being replaced.
/// Returns `None`, leaving `instructions` untouched, when the instruction cannot be fused.
fn process_instruction(
    header: &mut ModuleHeader,
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    instructions: &mut Vec<Instruction>,
    instruction_index: &mut usize,
) -> Option<Instruction> {
    let inst = &instructions[*instruction_index];
    // Basic sanity checks
    if inst.class.opcode != Op::CompositeConstruct {
        return None;
    }
    let inst_result_id = inst.result_id.unwrap();
    let vector_ty = inst.result_type.unwrap();
    let vector_ty_inst = types.get(&vector_ty)?;
    if vector_ty_inst.class.opcode != Op::TypeVector {
        return None;
    }
    let component_ty = vector_ty_inst.operands[0].unwrap_id_ref();
    let vector_width = vector_ty_inst.operands[1].unwrap_literal_int32();
    // `results` is the defining instruction for each scalar component of the final result.
    let results = inst
        .operands
        .iter()
        .map(|op| defs.get(&op.unwrap_id_ref()))
        .collect::<Option<Vec<_>>>()?;
    let operation_opcode = results[0].class.opcode;
    // Figure out the operands for the vectorized instruction.
    let composite_arguments = match_operands(types, defs, &results, vector_width)?;
    // Fun little optimization: SPIR-V has a fancy OpVectorTimesScalar instruction. If we have a
    // vector times a collection of scalars, and the scalars are all the same, reduce it!
    if operation_opcode == Op::FMul && composite_arguments.len() == 2 {
        if let (&IdentifiedOperand::Vector(composite), IdentifiedOperand::Scalars(scalars))
        | (IdentifiedOperand::Scalars(scalars), &IdentifiedOperand::Vector(composite)) =
            (&composite_arguments[0], &composite_arguments[1])
        {
            let scalar = scalars[0];
            if scalars.iter().skip(1).all(|&s| s == scalar) {
                return Some(Instruction::new(
                    Op::VectorTimesScalar,
                    Some(vector_ty),
                    Some(inst_result_id),
                    vec![Operand::IdRef(composite), Operand::IdRef(scalar)],
                ));
            }
        }
    }
    // Loose scalars get packed into a vector of type `vector_ty` below. That is only valid if
    // every scalar has the component type of `vector_ty`; operand types are not always equal to
    // the result type (e.g. signedness of integer operands, shift amounts, or OpExtInst
    // parameters). Check this before touching `instructions`, and skip the fusion if the type of
    // any scalar is different or unknown.
    for argument in &composite_arguments {
        if let IdentifiedOperand::Scalars(scalars) = argument {
            for scalar in scalars {
                let scalar_ty = defs
                    .get(scalar)
                    .or_else(|| types.get(scalar))
                    .and_then(|def| def.result_type);
                if scalar_ty != Some(component_ty) {
                    return None;
                }
            }
        }
    }
    // Map the operands into their concrete representations: vectors and non-values stay as-is, but
    // we need to emit an OpCompositeConstruct instruction for scalar collections.
    let operands = composite_arguments
        .into_iter()
        .map(|operand| match operand {
            IdentifiedOperand::Vector(composite) => Operand::IdRef(composite),
            IdentifiedOperand::NonValue(operand) => operand,
            IdentifiedOperand::Scalars(scalars) => {
                let id = super::id(header);
                // spirv-opt will transform this into an OpConstantComposite if all arguments are
                // constant, so we don't have to worry about that.
                instructions.insert(
                    *instruction_index,
                    Instruction::new(
                        Op::CompositeConstruct,
                        Some(vector_ty),
                        Some(id),
                        scalars.into_iter().map(Operand::IdRef).collect(),
                    ),
                );
                // The instruction being replaced moved one slot further back.
                *instruction_index += 1;
                Operand::IdRef(id)
            }
        })
        .collect();
    Some(Instruction::new(
        operation_opcode,
        Some(vector_ty),
        Some(inst_result_id),
        operands,
    ))
}
/// Fuse a sequence of scalar operations into a single vector operation. For example (result
/// types omitted for brevity):
/// ```text
/// %x_0 = OpCompositeExtract %x 0
/// %x_1 = OpCompositeExtract %x 1
/// %y_0 = OpCompositeExtract %y 0
/// %y_1 = OpCompositeExtract %y 1
/// %r_0 = OpIAdd %x_0 %y_0
/// %r_1 = OpIAdd %x_1 %y_1
/// %r = OpCompositeConstruct %r_0 %r_1
/// ```
/// into
/// ```text
/// %r = OpIAdd %x %y
/// ```
/// (We don't remove the intermediate instructions, however, in case they're used elsewhere - we
/// let spirv-opt remove them if they're actually dead)
///
/// `header` is used to allocate fresh result IDs and `types` maps module-level result IDs to
/// their defining instructions (see `collect_types`).
pub fn vector_ops(
    header: &mut ModuleHeader,
    types: &FxHashMap<Word, Instruction>,
    function: &mut Function,
) {
    // Snapshot of every result-producing instruction in the function, taken before any rewrite.
    let defs = function
        .all_inst_iter()
        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
        .collect::<FxHashMap<Word, Instruction>>();
    for block in &mut function.blocks {
        // It'd be nice to iterate over &mut block.instructions, but there's a weird case: if we
        // have a vector plus a collection of scalars, we want to pack the collection of scalars
        // into a vector and do a vector+vector op. That means we need to insert an extra
        // OpCompositeConstruct into the block, so, we need to manually keep track of the current
        // index and do a while loop.
        let mut instruction_index = 0;
        while instruction_index < block.instructions.len() {
            if let Some(result) = process_instruction(
                header,
                types,
                &defs,
                &mut block.instructions,
                &mut instruction_index,
            ) {
                // `process_instruction` has already moved `instruction_index` past anything it
                // inserted, so this overwrites the original OpCompositeConstruct.
                // Leave all the other instructions in the chain as dead code for other passes
                // to clean up.
                block.instructions[instruction_index] = result;
            }
            instruction_index += 1;
        }
    }
}

View File

@ -5,7 +5,7 @@ OpLine %5 7 12
%10 = OpArrayLength %11 %8 0
OpLine %5 7 0
%12 = OpCompositeInsert %13 %6 %14 0
%15 = OpCompositeInsert %13 %10 %12 1
%15 = OpCompositeConstruct %13 %6 %10
OpLine %5 8 21
%16 = OpULessThan %17 %9 %10
OpLine %5 8 21