diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs
index 3716566061..c45509feb9 100644
--- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs
+++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs
@@ -27,14 +27,14 @@ pub fn mem2reg(
     let preds = compute_preds(&func.blocks, &reachable);
     let idom = compute_idom(&preds, &reachable);
     let dominance_frontier = compute_dominance_frontier(&preds, &idom);
-    insert_phis_all(
+    while insert_phis_all(
         header,
         types_global_values,
         pointer_to_pointee,
         constants,
         &mut func.blocks,
-        dominance_frontier,
-    );
+        &dominance_frontier,
+    ) {}
 }
 
 fn label_to_index(blocks: &[Block], id: Word) -> usize {
@@ -146,15 +146,16 @@ fn compute_dominance_frontier(
     dominance_frontier
 }
 
+// Returns true if variables were rewritten
 fn insert_phis_all(
     header: &mut ModuleHeader,
     types_global_values: &mut Vec<Instruction>,
     pointer_to_pointee: &FxHashMap<Word, Word>,
     constants: &FxHashMap<Word, u32>,
     blocks: &mut [Block],
-    dominance_frontier: Vec<FxHashSet<usize>>,
-) {
-    let thing = blocks[0]
+    dominance_frontier: &[FxHashSet<usize>],
+) -> bool {
+    let var_maps_and_types = blocks[0]
         .instructions
         .iter()
         .filter(|inst| inst.class.opcode == Op::Variable)
@@ -167,8 +168,11 @@ fn insert_phis_all(
             ))
         })
         .collect::<Vec<_>>();
-    for &(ref var_map, base_var_type) in &thing {
-        let blocks_with_phi = insert_phis(blocks, &dominance_frontier, var_map);
+    if var_maps_and_types.is_empty() {
+        return false;
+    }
+    for &(ref var_map, base_var_type) in &var_maps_and_types {
+        let blocks_with_phi = insert_phis(blocks, dominance_frontier, var_map);
         let mut renamer = Renamer {
             header,
             types_global_values,
@@ -185,7 +189,8 @@
         apply_rewrite_rules(&renamer.rewrite_rules, blocks);
         remove_nops(blocks);
     }
-    remove_old_variables(blocks, &thing);
+    remove_old_variables(blocks, &var_maps_and_types);
+    true
 }
 
 #[derive(Debug)]
@@ -488,11 +493,14 @@ fn remove_nops(blocks: &mut [Block]) {
     }
 }
 
-fn remove_old_variables(blocks: &mut [Block], thing: &[(FxHashMap<Word, VarInfo>, u32)]) {
+fn remove_old_variables(
+    blocks: &mut [Block],
+    var_maps_and_types: &[(FxHashMap<Word, VarInfo>, u32)],
+) {
     blocks[0].instructions.retain(|inst| {
         inst.class.opcode != Op::Variable || {
             let result_id = inst.result_id.unwrap();
-            thing
+            var_maps_and_types
                 .iter()
                 .all(|(var_map, _)| !var_map.contains_key(&result_id))
         }
@@ -502,7 +510,9 @@ fn remove_old_variables(blocks: &mut [Block], thing: &[(FxHashMap<Word, VarInfo>,
         !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
             || inst.operands.iter().all(|op| {
                 op.id_ref_any().map_or(true, |id| {
-                    thing.iter().all(|(var_map, _)| !var_map.contains_key(&id))
+                    var_maps_and_types
+                        .iter()
+                        .all(|(var_map, _)| !var_map.contains_key(&id))
                 })
             })
     })
diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs
index 40c440e226..a2bd47d78f 100644
--- a/crates/rustc_codegen_spirv/src/linker/mod.rs
+++ b/crates/rustc_codegen_spirv/src/linker/mod.rs
@@ -7,6 +7,7 @@ mod import_export_link;
 mod inline;
 mod mem2reg;
 mod new_structurizer;
+mod peephole_opts;
 mod simple_passes;
 mod specializer;
 mod structurizer;
@@ -242,6 +243,15 @@ pub fn link(sess: &Session, mut inputs: Vec<&mut Module>, opts: &Options) -> Result<LinkResult> {
+    {
+        let _timer = sess.timer("peephole_opts");
+        let types = peephole_opts::types(&output);
+        for func in &mut output.functions {
+            peephole_opts::composite_construct(&types, func);
+            peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
+        }
+    }
+
diff --git a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs
new file mode 100644
--- /dev/null
+++ b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs
+use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
+use rspirv::spirv::{Op, Word};
+use rustc_data_structures::fx::FxHashMap;
+use rustc_middle::bug;
+
+pub fn types(module: &Module) -> FxHashMap<Word, Instruction> {
+    module
+        .types_global_values
+        .iter()
+        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
+        .collect()
+}
+
+fn composite_count(types: &FxHashMap<Word, Instruction>, ty_id: Word) -> Option<usize> {
+    let ty = types.get(&ty_id)?;
+    match ty.class.opcode {
+        Op::TypeStruct => Some(ty.operands.len()),
+        Op::TypeVector => Some(ty.operands[1].unwrap_literal_int32() as usize),
+        Op::TypeArray => {
+            let length_id = ty.operands[1].unwrap_id_ref();
+            let const_inst = types.get(&length_id)?;
+            if const_inst.class.opcode != Op::Constant {
+                return None;
+            }
+            let const_ty = types.get(&const_inst.result_type.unwrap())?;
+            if const_ty.class.opcode != Op::TypeInt {
+                return None;
+            }
+            let const_value = match const_inst.operands[0] {
+                Operand::LiteralInt32(v) => v as usize,
+                Operand::LiteralInt64(v) => v as usize,
+                _ => bug!(),
+            };
+            Some(const_value)
+        }
+        _ => None,
+    }
+}
+
+/// Given a chain of `OpCompositeInsert` instructions where all slots of the composite are
+/// assigned, replace the chain with a single `OpCompositeConstruct`.
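+/// For example, with the ids borrowed from the `index_user_dst` test diff below (where `%13`
+/// is a two-member struct type), the chain
+/// ```
+/// %12 = OpCompositeInsert %13 %6 %14 0
+/// %15 = OpCompositeInsert %13 %10 %12 1
+/// ```
+/// becomes
+/// ```
+/// %15 = OpCompositeConstruct %13 %6 %10
+/// ```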
+pub fn composite_construct(types: &FxHashMap<Word, Instruction>, function: &mut Function) {
+    let defs = function
+        .all_inst_iter()
+        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
+        .collect::<FxHashMap<_, _>>();
+    for block in &mut function.blocks {
+        for inst in &mut block.instructions {
+            if inst.class.opcode != Op::CompositeInsert {
+                continue;
+            }
+            // Get the number of components to expect
+            let component_count = match composite_count(types, inst.result_type.unwrap()) {
+                Some(c) => c,
+                None => continue,
+            };
+            // Remember a map of index -> value for that index. If any index is missing (None)
+            // afterwards, then we know not all slots have been filled in, and we should skip
+            // optimizing this chain.
+            let mut components = vec![None; component_count];
+            let mut cur_inst: &Instruction = inst;
+            // Start looping from the current instruction, through each instruction in the chain.
+            while cur_inst.class.opcode == Op::CompositeInsert {
+                if cur_inst.operands.len() != 3 {
+                    // If there's more than one index, skip optimizing this chain.
+                    break;
+                }
+                let value = cur_inst.operands[0].unwrap_id_ref();
+                let index = cur_inst.operands[2].unwrap_literal_int32() as usize;
+                if index >= components.len() {
+                    // Theoretically shouldn't happen, as it's invalid SPIR-V if the index is out
+                    // of bounds, but just stop optimizing instead of panicking here.
+                    break;
+                }
+                // Since we walk the chain backwards, the first write we see for a slot comes
+                // from the latest insert, which is the one that takes effect - don't let an
+                // earlier insert in the chain overwrite it.
+                if components[index].is_none() {
+                    components[index] = Some(value);
+                }
+                // Follow back one in the chain of OpCompositeInsert
+                cur_inst = match defs.get(&cur_inst.operands[1].unwrap_id_ref()) {
+                    Some(i) => i,
+                    None => break,
+                };
+            }
+            // If all components are filled in (collect() returns Some), replace it with
+            // `OpCompositeConstruct`
+            if let Some(composite_construct_operands) = components
+                .into_iter()
+                .map(|v| v.map(Operand::IdRef))
+                .collect::<Option<Vec<_>>>()
+            {
+                // Leave all the other instructions in the chain as dead code for other passes
+                // to clean up.
+                *inst = Instruction::new(
+                    Op::CompositeConstruct,
+                    inst.result_type,
+                    inst.result_id,
+                    composite_construct_operands,
+                );
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+enum IdentifiedOperand {
+    /// The operand to the vectorized operation is a straight-up vector.
+    Vector(Word),
+    /// The operand to the vectorized operation is a collection of scalars that need to be packed
+    /// together with OpCompositeConstruct before using the vectorized operation.
+    Scalars(Vec<Word>),
+    /// The operand to the vectorized operation is some non-value: for example, the `instruction`
+    /// operand in OpExtInst.
+    NonValue(Operand),
+}
+
+/// Given an ID ref to an `OpCompositeExtract`, get the vector it's extracting from, and the field
+/// index.
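+/// E.g. if `id` names `%x_1 = OpCompositeExtract %float %x 1` (hypothetical ids), this returns
+/// `Some((%x, 1))` - provided `%x` is an `OpTypeVector` of width `vector_width`.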
+fn get_composite_and_index(
+    types: &FxHashMap<Word, Instruction>,
+    defs: &FxHashMap<Word, Instruction>,
+    id: Word,
+    vector_width: u32,
+) -> Option<(Word, u32)> {
+    let inst = defs.get(&id)?;
+    if inst.class.opcode != Op::CompositeExtract {
+        return None;
+    }
+    if inst.operands.len() != 2 {
+        // If the index is more than one deep, bail.
+        return None;
+    }
+    let composite = inst.operands[0].unwrap_id_ref();
+    let index = inst.operands[1].unwrap_literal_int32();
+
+    let composite_def = defs.get(&composite).or_else(|| types.get(&composite))?;
+    let vector_def = types.get(&composite_def.result_type.unwrap())?;
+
+    // Make sure it's a vector and has the width we're expecting.
+    // Width mismatch would be doing something like `vec2(a.x + b.x, a.y + b.y)` where `a` is a
+    // vec4 - if we optimized it to just `a + b`, it'd be incorrect.
+    if vector_def.class.opcode != Op::TypeVector
+        || vector_width != vector_def.operands[1].unwrap_literal_int32()
+    {
+        return None;
+    }
+
+    Some((composite, index))
+}
+
+/// Given a bunch of operands (`results[n].operands[operand_index]`), where all those operands
+/// refer to an `OpCompositeExtract` of the same vector (with proper indices, etc.), return that
+/// vector.
+fn match_vector_operand(
+    types: &FxHashMap<Word, Instruction>,
+    defs: &FxHashMap<Word, Instruction>,
+    results: &[&Instruction],
+    operand_index: usize,
+    vector_width: u32,
+) -> Option<Word> {
+    let operand_zero = match results[0].operands[operand_index] {
+        Operand::IdRef(id) => id,
+        _ => {
+            return None;
+        }
+    };
+    // Extract the composite used for the first component.
+    let composite_zero = match get_composite_and_index(types, defs, operand_zero, vector_width) {
+        Some((composite_zero, 0)) => composite_zero,
+        _ => {
+            return None;
+        }
+    };
+    // Check the same composite is used for every other component (and indices line up)
+    for (expected_index, result) in results.iter().enumerate().skip(1) {
+        let operand = match result.operands[operand_index] {
+            Operand::IdRef(id) => id,
+            _ => {
+                return None;
+            }
+        };
+        let (composite, actual_index) =
+            match get_composite_and_index(types, defs, operand, vector_width) {
+                Some(x) => x,
+                None => {
+                    return None;
+                }
+            };
+        // If this component comes from a different composite, or its index isn't the one we
+        // expect, bail.
+        if composite != composite_zero || expected_index != actual_index as usize {
+            return None;
+        }
+    }
+    Some(composite_zero)
+}
+
+/// Either extract out the vector behind each scalar component (see `match_vector_operand`), or
+/// just return the collection of scalars for this operand (to be constructed into a vector via
+/// `OpCompositeConstruct`).
+fn match_vector_or_scalars_operand(
+    types: &FxHashMap<Word, Instruction>,
+    defs: &FxHashMap<Word, Instruction>,
+    results: &[&Instruction],
+    operand_index: usize,
+    vector_width: u32,
+) -> Option<IdentifiedOperand> {
+    if let Some(composite) =
+        match_vector_operand(types, defs, results, operand_index, vector_width)
+    {
+        Some(IdentifiedOperand::Vector(composite))
+    } else {
+        let operands = results
+            .iter()
+            .map(|inst| match inst.operands[operand_index] {
+                Operand::IdRef(id) => Some(id),
+                _ => None,
+            })
+            .collect::<Option<Vec<_>>>()?;
+        Some(IdentifiedOperand::Scalars(operands))
+    }
+}
+
+/// Make sure all the operands are the same at this index, and return that operand. This is used
+/// in, for example, the `instruction` operand for `OpExtInst`.
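+/// E.g. when fusing a group of `OpExtInst`s, every component must call into the same extended
+/// instruction set (say `GLSL.std.450`) with the same extended opcode (say `Sin`) for the
+/// fusion to be valid.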
+fn match_all_same_operand(results: &[&Instruction], operand_index: usize) -> Option<Operand> {
+    let operand_zero = &results[0].operands[operand_index];
+    if results
+        .iter()
+        .skip(1)
+        .all(|inst| &inst.operands[operand_index] == operand_zero)
+    {
+        Some(operand_zero.clone())
+    } else {
+        None
+    }
+}
+
+/// Find the proper operands for the vectorized operation. This means finding the backing vector
+/// for each scalar component, etc.
+fn match_operands(
+    types: &FxHashMap<Word, Instruction>,
+    defs: &FxHashMap<Word, Instruction>,
+    results: &[&Instruction],
+    vector_width: u32,
+) -> Option<Vec<IdentifiedOperand>> {
+    let operation_opcode = results[0].class.opcode;
+    // Check to make sure they're all the same opcode, and have the same number of arguments.
+    if results.iter().skip(1).any(|r| {
+        r.class.opcode != operation_opcode || r.operands.len() != results[0].operands.len()
+    }) {
+        return None;
+    }
+    // TODO: There are probably other instructions relevant here.
+    match operation_opcode {
+        Op::IAdd
+        | Op::FAdd
+        | Op::ISub
+        | Op::FSub
+        | Op::IMul
+        | Op::FMul
+        | Op::UDiv
+        | Op::SDiv
+        | Op::FDiv
+        | Op::UMod
+        | Op::SRem
+        | Op::FRem
+        | Op::FMod
+        | Op::ShiftRightLogical
+        | Op::ShiftRightArithmetic
+        | Op::ShiftLeftLogical
+        | Op::BitwiseOr
+        | Op::BitwiseXor
+        | Op::BitwiseAnd => {
+            let left = match_vector_or_scalars_operand(types, defs, results, 0, vector_width)?;
+            let right = match_vector_or_scalars_operand(types, defs, results, 1, vector_width)?;
+            match (left, right) {
+                // Style choice: If all arguments are scalars, don't fuse this operation.
+                (IdentifiedOperand::Scalars(_), IdentifiedOperand::Scalars(_)) => None,
+                (left, right) => Some(vec![left, right]),
+            }
+        }
+        Op::SNegate | Op::FNegate | Op::Not | Op::BitReverse => {
+            let value = match_vector_operand(types, defs, results, 0, vector_width)?;
+            Some(vec![IdentifiedOperand::Vector(value)])
+        }
+        Op::ExtInst => {
+            let set = match_all_same_operand(results, 0)?;
+            let instruction = match_all_same_operand(results, 1)?;
+            let parameters = (2..results[0].operands.len())
+                .map(|i| match_vector_or_scalars_operand(types, defs, results, i, vector_width));
+            // Do some trickery to reduce allocations.
+            let operands = std::array::IntoIter::new([
+                Some(IdentifiedOperand::NonValue(set)),
+                Some(IdentifiedOperand::NonValue(instruction)),
+            ])
+            .chain(parameters)
+            .collect::<Option<Vec<_>>>()?;
+            if operands
+                .iter()
+                .skip(2)
+                .all(|p| matches!(p, &IdentifiedOperand::Scalars(_)))
+            {
+                // Style choice: If all arguments are scalars, don't fuse this operation.
+                return None;
+            }
+            Some(operands)
+        }
+        _ => None,
+    }
+}
+
+fn process_instruction(
+    header: &mut ModuleHeader,
+    types: &FxHashMap<Word, Instruction>,
+    defs: &FxHashMap<Word, Instruction>,
+    instructions: &mut Vec<Instruction>,
+    instruction_index: &mut usize,
+) -> Option<Instruction> {
+    let inst = &instructions[*instruction_index];
+    // Basic sanity checks
+    if inst.class.opcode != Op::CompositeConstruct {
+        return None;
+    }
+    let inst_result_id = inst.result_id.unwrap();
+    let vector_ty = inst.result_type.unwrap();
+    let vector_ty_inst = match types.get(&vector_ty) {
+        Some(inst) => inst,
+        _ => return None,
+    };
+    if vector_ty_inst.class.opcode != Op::TypeVector {
+        return None;
+    }
+    let vector_width = vector_ty_inst.operands[1].unwrap_literal_int32();
+    // `results` is the defining instruction for each scalar component of the final result.
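+    // E.g. for `%r = OpCompositeConstruct %v2float %a %b` (hypothetical ids), `results` holds
+    // the instructions defining `%a` and `%b`.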
+    let results = match inst
+        .operands
+        .iter()
+        .map(|op| defs.get(&op.unwrap_id_ref()))
+        .collect::<Option<Vec<_>>>()
+    {
+        Some(r) => r,
+        None => return None,
+    };
+
+    let operation_opcode = results[0].class.opcode;
+    // Figure out the operands for the vectorized instruction.
+    let composite_arguments = match_operands(types, defs, &results, vector_width)?;
+
+    // Fun little optimization: SPIR-V has a fancy OpVectorTimesScalar instruction. If we have a
+    // vector times a collection of scalars, and the scalars are all the same, reduce it!
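+    // E.g. `vec2(v.x * s, v.y * s)` can collapse to a single `OpVectorTimesScalar %v %s`
+    // (hypothetical ids).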
+    if operation_opcode == Op::FMul && composite_arguments.len() == 2 {
+        if let (&IdentifiedOperand::Vector(composite), IdentifiedOperand::Scalars(scalars))
+        | (IdentifiedOperand::Scalars(scalars), &IdentifiedOperand::Vector(composite)) =
+            (&composite_arguments[0], &composite_arguments[1])
+        {
+            let scalar = scalars[0];
+            if scalars.iter().skip(1).all(|&s| s == scalar) {
+                return Some(Instruction::new(
+                    Op::VectorTimesScalar,
+                    inst.result_type,
+                    inst.result_id,
+                    vec![Operand::IdRef(composite), Operand::IdRef(scalar)],
+                ));
+            }
+        }
+    }
+
+    // Map the operands into their concrete representations: vectors and non-values stay as-is, but
+    // we need to emit an OpCompositeConstruct instruction for scalar collections.
+    let operands = composite_arguments
+        .into_iter()
+        .map(|operand| match operand {
+            IdentifiedOperand::Vector(composite) => Operand::IdRef(composite),
+            IdentifiedOperand::NonValue(operand) => operand,
+            IdentifiedOperand::Scalars(scalars) => {
+                let id = super::id(header);
+                // spirv-opt will transform this into an OpConstantComposite if all arguments are
+                // constant, so we don't have to worry about that.
+                instructions.insert(
+                    *instruction_index,
+                    Instruction::new(
+                        Op::CompositeConstruct,
+                        Some(vector_ty),
+                        Some(id),
+                        scalars.into_iter().map(Operand::IdRef).collect(),
+                    ),
+                );
+                *instruction_index += 1;
+                Operand::IdRef(id)
+            }
+        })
+        .collect();
+
+    Some(Instruction::new(
+        operation_opcode,
+        Some(vector_ty),
+        Some(inst_result_id),
+        operands,
+    ))
+}
+
+/// Fuse a sequence of scalar operations into a single vector operation. For example:
+/// ```
+/// %x_0 = OpCompositeExtract %x 0
+/// %x_1 = OpCompositeExtract %x 1
+/// %y_0 = OpCompositeExtract %y 0
+/// %y_1 = OpCompositeExtract %y 1
+/// %r_0 = OpFAdd %x_0 %y_0
+/// %r_1 = OpFAdd %x_1 %y_1
+/// %r = OpCompositeConstruct %r_0 %r_1
+/// ```
+/// into
+/// ```
+/// %r = OpFAdd %x %y
+/// ```
+/// (We don't remove the intermediate instructions, however, in case they're used elsewhere - we
+/// let spirv-opt remove them if they're actually dead)
+pub fn vector_ops(
+    header: &mut ModuleHeader,
+    types: &FxHashMap<Word, Instruction>,
+    function: &mut Function,
+) {
+    let defs = function
+        .all_inst_iter()
+        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
+        .collect::<FxHashMap<_, _>>();
+    for block in &mut function.blocks {
+        // It'd be nice to iterate over &mut block.instructions, but there's a weird case: if we
+        // have a vector plus a collection of scalars, we want to pack the collection of scalars
+        // into a vector and do a vector+vector op. That means we need to insert an extra
+        // OpCompositeConstruct into the block, so we need to manually keep track of the current
+        // index and do a while loop.
+        let mut instruction_index = 0;
+        while instruction_index < block.instructions.len() {
+            if let Some(result) = process_instruction(
+                header,
+                types,
+                &defs,
+                &mut block.instructions,
+                &mut instruction_index,
+            ) {
+                // Leave all the other instructions in the chain as dead code for other passes
+                // to clean up.
+                block.instructions[instruction_index] = result;
+            }
+
+            instruction_index += 1;
+        }
+    }
+}
diff --git a/tests/ui/dis/index_user_dst.stderr b/tests/ui/dis/index_user_dst.stderr
index 162201fd2f..2cc20726c0 100644
--- a/tests/ui/dis/index_user_dst.stderr
+++ b/tests/ui/dis/index_user_dst.stderr
@@ -5,7 +5,7 @@ OpLine %5 7 12
 %10 = OpArrayLength %11 %8 0
 OpLine %5 7 0
 %12 = OpCompositeInsert %13 %6 %14 0
-%15 = OpCompositeInsert %13 %10 %12 1
+%15 = OpCompositeConstruct %13 %6 %10
 OpLine %5 8 21
 %16 = OpULessThan %17 %9 %10
 OpLine %5 8 21