diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 1cb5931849..8f8020b654 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -31,6 +31,23 @@ pub fn load(bytes: &[u8]) -> rspirv::dr::Module { loader.module() } +fn operand_idref(op: &rspirv::dr::Operand) -> Option { + match *op { + rspirv::dr::Operand::IdMemorySemantics(w) + | rspirv::dr::Operand::IdScope(w) + | rspirv::dr::Operand::IdRef(w) => Some(w), + _ => None, + } +} +fn operand_idref_mut(op: &mut rspirv::dr::Operand) -> Option<&mut spirv::Word> { + match op { + rspirv::dr::Operand::IdMemorySemantics(w) + | rspirv::dr::Operand::IdScope(w) + | rspirv::dr::Operand::IdRef(w) => Some(w), + _ => None, + } +} + fn shift_ids(module: &mut rspirv::dr::Module, add: u32) { module.all_inst_iter_mut().for_each(|inst| { if let Some(ref mut result_id) = &mut inst.result_id { @@ -41,11 +58,10 @@ fn shift_ids(module: &mut rspirv::dr::Module, add: u32) { *result_type += add; } - inst.operands.iter_mut().for_each(|op| match op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => *w += add, - _ => {} + inst.operands.iter_mut().for_each(|op| { + if let Some(w) = operand_idref_mut(op) { + *w += add + } }) }); } @@ -58,15 +74,12 @@ fn replace_all_uses_with(module: &mut rspirv::dr::Module, before: u32, after: u3 } } - inst.operands.iter_mut().for_each(|op| match op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => { + inst.operands.iter_mut().for_each(|op| { + if let Some(w) = operand_idref_mut(op) { if *w == before { *w = after } } - _ => {} }) }); } @@ -130,16 +143,7 @@ fn kill_with_id(insts: &mut Vec, id: u32) { return false; } - match inst.operands[0] { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) - if w == id => - { - true - } - _ => false, - } + matches!(operand_idref(&inst.operands[0]), Some(w) if w == id) }) } @@ -188,7 +192,7 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) { *id = rules.get(id).copied().unwrap_or(*id); } for op in &mut inst.operands { - if let rspirv::dr::Operand::IdRef(ref mut id) = op { + if let Some(id) = operand_idref_mut(op) { *id = rules.get(id).copied().unwrap_or(*id); } } @@ -268,6 +272,7 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) { // 2) Erase this instruction. Because we're iterating over this vec, removing an element is hard, so // clear it with OpNop, and then remove it in the retain() call below. assert!(old_value.is_none()); + println!("killing duplicate {:?} -> {}", *inst, entry.get()); *inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]); } } @@ -278,6 +283,10 @@ fn remove_duplicate_types(module: &mut rspirv::dr::Module) { .types_global_values .retain(|op| op.class.opcode != spirv::Op::Nop); + for (key, value) in &rewrite_rules { + println!("rewrite {} -> {}", key, value); + } + // Apply the rewrite rules to the whole module for inst in module.all_inst_iter_mut() { rewrite_inst_with_rules(inst, &rewrite_rules); @@ -572,6 +581,8 @@ fn remove_zombies(module: &mut rspirv::dr::Module) { ) = (&inst.operands[0], &inst.operands[2]) { return Some((id, reason.to_string())); + } else { + panic!("Invalid OpDecorateString") } } None @@ -595,12 +606,9 @@ fn remove_zombies(module: &mut rspirv::dr::Module) { return Some(reason); } } - inst.operands.iter().find_map(|op| match op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => zombie.get(w).copied(), - _ => None, - }) + inst.operands + .iter() + .find_map(|op| operand_idref(op).and_then(|w| zombie.get(&w).copied())) } fn is_zombie<'a>( @@ -776,19 +784,34 @@ fn compact_ids(module: &mut rspirv::dr::Module) -> u32 { *result_type = insert(*result_type); } - inst.operands.iter_mut().for_each(|op| match op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => { + inst.operands.iter_mut().for_each(|op| { + if let Some(w) = operand_idref_mut(op) { *w = insert(*w); } - _ => {} }) }); remap.len() as u32 + 1 } +fn max_bound(module: &rspirv::dr::Module) -> u32 { + let mut max = 0; + for inst in module.all_inst_iter() { + if let Some(result_id) = inst.result_id { + max = max.max(result_id); + } + if let Some(result_type) = inst.result_type { + max = max.max(result_type); + } + inst.operands.iter().for_each(|op| { + if let Some(w) = operand_idref(op) { + max = max.max(w); + } + }) + } + max + 1 +} + fn sort_globals(module: &mut rspirv::dr::Module) { let mut ts = TopologicalSort::::new(); @@ -799,13 +822,8 @@ fn sort_globals(module: &mut rspirv::dr::Module) { } for op in &t.operands { - match op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => { - ts.add_dependency(*w, result_id); // the op defining the IdRef should come before our op / result_id - } - _ => {} + if let Some(w) = operand_idref(op) { + ts.add_dependency(w, result_id); // the op defining the IdRef should come before our op / result_id } } } @@ -954,14 +972,9 @@ enum AggregateType { } fn op_def(def: &DefAnalyzer, operand: &rspirv::dr::Operand) -> rspirv::dr::Instruction { - def.def(match operand { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => *w, - _ => panic!("Expected ID"), - }) - .unwrap() - .clone() + def.def(operand_idref(operand).expect("Expected ID")) + .unwrap() + .clone() } fn extract_literal_int_as_u64(op: &rspirv::dr::Operand) -> u64 { @@ -1168,8 +1181,14 @@ pub fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result