diff --git a/Cargo.lock b/Cargo.lock index c542133f02..eb49f5587a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1083,7 +1083,7 @@ dependencies = [ [[package]] name = "rspirv" version = "0.7.0" -source = "git+https://github.com/gfx-rs/rspirv.git?rev=1addc7d33ae1460ffa683e2e6311e466ac876c23#1addc7d33ae1460ffa683e2e6311e466ac876c23" +source = "git+https://github.com/gfx-rs/rspirv.git?rev=f11f8797bd4df2d1d22cf10767b39a5119c57551#f11f8797bd4df2d1d22cf10767b39a5119c57551" dependencies = [ "derive_more", "fxhash", @@ -1272,7 +1272,7 @@ version = "0.1.0" [[package]] name = "spirv_headers" version = "1.5.0" -source = "git+https://github.com/gfx-rs/rspirv.git?rev=1addc7d33ae1460ffa683e2e6311e466ac876c23#1addc7d33ae1460ffa683e2e6311e466ac876c23" +source = "git+https://github.com/gfx-rs/rspirv.git?rev=f11f8797bd4df2d1d22cf10767b39a5119c57551#f11f8797bd4df2d1d22cf10767b39a5119c57551" dependencies = [ "bitflags 1.2.1", "num-traits 0.2.12", diff --git a/Cargo.toml b/Cargo.toml index 6fb2e3ca24..f71b4c1645 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,4 @@ members = [ ] [patch.crates-io] -rspirv = { git = "https://github.com/gfx-rs/rspirv.git", rev = "1addc7d33ae1460ffa683e2e6311e466ac876c23" } +rspirv = { git = "https://github.com/gfx-rs/rspirv.git", rev = "f11f8797bd4df2d1d22cf10767b39a5119c57551" } diff --git a/rspirv-linker/src/capability_computation.rs b/rspirv-linker/src/capability_computation.rs index dca0325f43..90425c3571 100644 --- a/rspirv-linker/src/capability_computation.rs +++ b/rspirv-linker/src/capability_computation.rs @@ -1,4 +1,4 @@ -use rspirv::dr::{Module, Operand}; +use rspirv::dr::Module; use rspirv::spirv::{Capability, Op}; use std::collections::HashSet; @@ -12,32 +12,26 @@ fn compute_capabilities(module: &Module) -> HashSet { for inst in module.all_inst_iter() { set.extend(inst.class.capabilities); match inst.class.opcode { - Op::TypeInt => match inst.operands[0] { - Operand::LiteralInt32(width) => match width { - 8 => { - set.insert(Capability::Int8); - } - 16 => { - set.insert(Capability::Int16); - } - 64 => { - set.insert(Capability::Int64); - } - _ => {} - }, - _ => panic!(), + Op::TypeInt => match inst.operands[0].unwrap_literal_int32() { + 8 => { + set.insert(Capability::Int8); + } + 16 => { + set.insert(Capability::Int16); + } + 64 => { + set.insert(Capability::Int64); + } + _ => {} }, - Op::TypeFloat => match inst.operands[0] { - Operand::LiteralInt32(width) => match width { - 16 => { - set.insert(Capability::Float16); - } - 64 => { - set.insert(Capability::Float64); - } - _ => {} - }, - _ => panic!(), + Op::TypeFloat => match inst.operands[0].unwrap_literal_int32() { + 16 => { + set.insert(Capability::Float16); + } + 64 => { + set.insert(Capability::Float64); + } + _ => {} }, _ => {} } @@ -53,11 +47,7 @@ fn compute_capabilities(module: &Module) -> HashSet { fn remove_capabilities(module: &mut Module, set: &HashSet) { module.capabilities.retain(|inst| { - inst.class.opcode != Op::Capability - || set.contains(match &inst.operands[0] { - Operand::Capability(s) => s, - _ => panic!(), - }) + inst.class.opcode != Op::Capability || set.contains(&inst.operands[0].unwrap_capability()) }); } @@ -65,21 +55,12 @@ pub fn remove_extra_extensions(module: &mut Module) { // TODO: Make this more generalized once this gets more advanced. let has_intel_integer_cap = module.capabilities.iter().any(|inst| { inst.class.opcode == Op::Capability - && match inst.operands[0] { - Operand::Capability(s) => s == Capability::IntegerFunctions2INTEL, - _ => panic!(), - } + && inst.operands[0].unwrap_capability() == Capability::IntegerFunctions2INTEL }); if !has_intel_integer_cap { module.extensions.retain(|inst| { inst.class.opcode != Op::Extension - || match &inst.operands[0] { - Operand::LiteralString(s) if s == "SPV_INTEL_shader_integer_functions2" => { - false - } - Operand::LiteralString(_) => true, - _ => panic!(), - } + || inst.operands[0].unwrap_literal_string() != "SPV_INTEL_shader_integer_functions2" }) } } diff --git a/rspirv-linker/src/dce.rs b/rspirv-linker/src/dce.rs index 5db2214c5e..6551191bc4 100644 --- a/rspirv-linker/src/dce.rs +++ b/rspirv-linker/src/dce.rs @@ -1,4 +1,3 @@ -use crate::operand_idref; use rspirv::dr::{Instruction, Module}; use rspirv::spirv::Word; use std::collections::HashSet; @@ -50,7 +49,7 @@ fn root(inst: &Instruction, rooted: &mut HashSet) -> bool { any |= rooted.insert(id); } for op in &inst.operands { - if let Some(id) = operand_idref(op) { + if let Some(id) = op.id_ref_any() { any |= rooted.insert(id); } } @@ -65,7 +64,7 @@ fn is_rooted(inst: &Instruction, rooted: &HashSet) -> bool { // referenced by roots inst.operands .iter() - .any(|op| operand_idref(op).map_or(false, |w| rooted.contains(&w))) + .any(|op| op.id_ref_any().map_or(false, |w| rooted.contains(&w))) } } diff --git a/rspirv-linker/src/def_analyzer.rs b/rspirv-linker/src/def_analyzer.rs index 76a1561d23..0505c4cfa1 100644 --- a/rspirv-linker/src/def_analyzer.rs +++ b/rspirv-linker/src/def_analyzer.rs @@ -1,4 +1,3 @@ -use crate::operand_idref; use rspirv::dr::{Instruction, Module, Operand}; use std::collections::HashMap; @@ -40,7 +39,7 @@ impl<'a> DefAnalyzer<'a> { /// /// Panics when provided an operand that doesn't reference an id, or that id is missing. pub fn op_def(&self, operand: &Operand) -> Instruction { - self.def(operand_idref(operand).expect("Expected ID")) + self.def(operand.id_ref_any().expect("Expected ID")) .unwrap() .clone() } diff --git a/rspirv-linker/src/duplicates.rs b/rspirv-linker/src/duplicates.rs index 178f260f4d..20d611986c 100644 --- a/rspirv-linker/src/duplicates.rs +++ b/rspirv-linker/src/duplicates.rs @@ -1,4 +1,3 @@ -use crate::{operand_idref, operand_idref_mut}; use rspirv::binary::Assemble; use rspirv::dr::{Instruction, Module, Operand}; use rspirv::spirv::{Op, Word}; @@ -9,21 +8,14 @@ pub fn remove_duplicate_extensions(module: &mut Module) { module.extensions.retain(|inst| { inst.class.opcode != Op::Extension - || set.insert(match &inst.operands[0] { - Operand::LiteralString(s) => s.clone(), - _ => panic!(), - }) + || set.insert(inst.operands[0].unwrap_literal_string().to_string()) }); } pub fn remove_duplicate_capablities(module: &mut Module) { let mut set = HashSet::new(); module.capabilities.retain(|inst| { - inst.class.opcode != Op::Capability - || set.insert(match inst.operands[0] { - Operand::Capability(s) => s, - _ => panic!(), - }) + inst.class.opcode != Op::Capability || set.insert(inst.operands[0].unwrap_capability()) }); } @@ -80,7 +72,7 @@ fn gather_annotations(annotations: &[Instruction]) -> HashMap> { let mut map = HashMap::new(); for inst in annotations { if inst.class.opcode == Op::Decorate || inst.class.opcode == Op::MemberDecorate { - match map.entry(operand_idref(&inst.operands[0]).unwrap()) { + match map.entry(inst.operands[0].id_ref_any().unwrap()) { hash_map::Entry::Vacant(entry) => { entry.insert(vec![make_annotation_key(inst)]); } @@ -142,7 +134,7 @@ fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &HashMap) { *id = rules.get(id).copied().unwrap_or(*id); } for op in &mut inst.operands { - if let Some(id) = operand_idref_mut(op) { + if let Some(id) = op.id_ref_any_mut() { *id = rules.get(id).copied().unwrap_or(*id); } } diff --git a/rspirv-linker/src/import_export_link.rs b/rspirv-linker/src/import_export_link.rs index 8ccabade52..46d5f0e1ec 100644 --- a/rspirv-linker/src/import_export_link.rs +++ b/rspirv-linker/src/import_export_link.rs @@ -1,6 +1,6 @@ use crate::ty::trans_aggregate_type; -use crate::{operand_idref, operand_idref_mut, print_type, DefAnalyzer, LinkerError, Result}; -use rspirv::dr::{Instruction, Module, Operand}; +use crate::{print_type, DefAnalyzer, LinkerError, Result}; +use rspirv::dr::{Instruction, Module}; use rspirv::spirv::{Capability, Decoration, LinkageType, Op, Word}; use std::collections::{HashMap, HashSet}; @@ -60,22 +60,11 @@ fn find_import_export_pairs_and_killed_params( fn get_linkage_inst(inst: &Instruction) -> Option<(Word, &str, LinkageType)> { if inst.class.opcode == Op::Decorate - && inst.operands[1] == Operand::Decoration(Decoration::LinkageAttributes) + && inst.operands[1].unwrap_decoration() == Decoration::LinkageAttributes { - let id = match inst.operands[0] { - Operand::IdRef(i) => i, - _ => panic!("Expected IdRef"), - }; - - let name = match &inst.operands[2] { - Operand::LiteralString(s) => s, - _ => panic!("Expected LiteralString"), - }; - - let linkage_ty = match inst.operands[3] { - Operand::LinkageType(t) => t, - _ => panic!("Expected LinkageType"), - }; + let id = inst.operands[0].unwrap_id_ref(); + let name = inst.operands[2].unwrap_literal_string(); + let linkage_ty = inst.operands[3].unwrap_linkage_type(); Some((id, name, linkage_ty)) } else { None @@ -91,13 +80,7 @@ fn get_type_for_link(defs: &DefAnalyzer, id: Word) -> Word { Op::Variable => def_inst.result_type.unwrap(), // Note: the result_type of OpFunction is the return type, not the function type. The // function type is in operands[1]. - Op::Function => { - if let Operand::IdRef(id) = def_inst.operands[1] { - id - } else { - panic!("Expected IdRef"); - } - } + Op::Function => def_inst.operands[1].unwrap_id_ref(), _ => panic!("Unexpected op"), } } @@ -159,7 +142,7 @@ fn replace_all_uses_with(module: &mut Module, rules: &HashMap) { } inst.operands.iter_mut().for_each(|op| { - if let Some(w) = operand_idref_mut(op) { + if let Some(w) = op.id_ref_any_mut() { if let Some(&rewrite) = rules.get(w) { *w = rewrite; } @@ -182,13 +165,13 @@ fn kill_linkage_instructions(module: &mut Module, rewrite_rules: &HashMap` in a function! So, we inline all functions that take these "illegal" pointers, then //! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer. +use crate::apply_rewrite_rules; use crate::mem2reg::compute_preds; -use crate::{apply_rewrite_rules, operand_idref}; use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand}; use rspirv::spirv::{FunctionControl, Op, StorageClass, Word}; use std::collections::{HashMap, HashSet}; @@ -40,10 +40,10 @@ pub fn inline(module: &mut Module) { }); // Drop OpName etc. for inlined functions module.debugs.retain(|inst| { - !inst - .operands - .iter() - .any(|op| operand_idref(op).map_or(false, |id| dropped_ids.contains(&id))) + !inst.operands.iter().any(|op| { + op.id_ref_any() + .map_or(false, |id| dropped_ids.contains(&id)) + }) }); let mut inliner = Inliner { header: &mut module.header.as_mut().unwrap(), @@ -72,14 +72,8 @@ fn compute_disallowed_argument_types(module: &Module) -> HashSet { for inst in &module.types_global_values { match inst.class.opcode { Op::TypePointer => { - let storage_class = match inst.operands[0] { - Operand::StorageClass(x) => x, - _ => panic!(), - }; - let pointee = match inst.operands[1] { - Operand::IdRef(x) => x, - _ => panic!(), - }; + let storage_class = inst.operands[0].unwrap_storage_class(); + let pointee = inst.operands[1].unwrap_id_ref(); if !allowed_argument_storage_classes.contains(&storage_class) || disallowed_pointees.contains(&pointee) || disallowed_argument_types.contains(&pointee) @@ -92,7 +86,7 @@ fn compute_disallowed_argument_types(module: &Module) -> HashSet { if inst .operands .iter() - .map(|op| operand_idref(op).unwrap()) + .map(|op| op.id_ref_any().unwrap()) .any(|id| disallowed_argument_types.contains(&id)) { disallowed_argument_types.insert(inst.result_id.unwrap()); @@ -100,14 +94,14 @@ fn compute_disallowed_argument_types(module: &Module) -> HashSet { if inst .operands .iter() - .map(|op| operand_idref(op).unwrap()) + .map(|op| op.id_ref_any().unwrap()) .any(|id| disallowed_pointees.contains(&id)) { disallowed_pointees.insert(inst.result_id.unwrap()); } } Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector => { - let id = operand_idref(&inst.operands[0]).unwrap(); + let id = inst.operands[0].id_ref_any().unwrap(); if disallowed_argument_types.contains(&id) { disallowed_argument_types.insert(inst.result_id.unwrap()); } @@ -123,10 +117,7 @@ fn compute_disallowed_argument_types(module: &Module) -> HashSet { fn should_inline(disallowed_argument_types: &HashSet, function: &Function) -> bool { let def = function.def.as_ref().unwrap(); - let control = match def.operands[0] { - Operand::FunctionControl(control) => control, - _ => panic!(), - }; + let control = def.operands[0].unwrap_function_control(); control.contains(FunctionControl::INLINE) || function .parameters @@ -160,8 +151,8 @@ impl Inliner<'_, '_> { // TODO: This is horribly slow, fix this let existing = self.types_global_values.iter().find(|inst| { inst.class.opcode == Op::TypePointer - && inst.operands[0] == Operand::StorageClass(StorageClass::Function) - && inst.operands[1] == Operand::IdRef(pointee) + && inst.operands[0].unwrap_storage_class() == StorageClass::Function + && inst.operands[1].unwrap_id_ref() == pointee }); if let Some(existing) = existing { return existing.result_id.unwrap(); @@ -203,7 +194,7 @@ impl Inliner<'_, '_> { index, inst, self.functions - .get(&operand_idref(&inst.operands[0]).unwrap()) + .get(&inst.operands[0].id_ref_any().unwrap()) .unwrap(), ) }) @@ -226,7 +217,7 @@ impl Inliner<'_, '_> { .operands .iter() .skip(1) - .map(|op| operand_idref(op).unwrap()); + .map(|op| op.id_ref_any().unwrap()); let callee_parameters = callee.parameters.iter().map(|inst| { assert!(inst.class.opcode == Op::FunctionParameter); inst.result_id.unwrap() @@ -330,7 +321,7 @@ fn get_inlined_blocks( let last = block.instructions.last().unwrap(); if let Op::Return | Op::ReturnValue = last.class.opcode { if Op::ReturnValue == last.class.opcode { - let return_value = operand_idref(&last.operands[0]).unwrap(); + let return_value = last.operands[0].id_ref_any().unwrap(); block.instructions.insert( block.instructions.len() - 1, Instruction::new( diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 58736d9ffa..a05a723f2c 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -56,23 +56,6 @@ fn id(header: &mut ModuleHeader) -> Word { result } -fn operand_idref(op: &Operand) -> Option { - match *op { - Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w), - _ => None, - } -} -fn operand_idref_mut(op: &mut Operand) -> Option<&mut Word> { - match op { - Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w), - _ => None, - } -} - -fn label_of(block: &Block) -> Word { - block.label.as_ref().unwrap().result_id.unwrap() -} - fn print_type(defs: &DefAnalyzer, ty: &Instruction) -> String { format!("{}", ty::trans_aggregate_type(defs, ty).unwrap()) } @@ -85,13 +68,6 @@ fn extract_literal_int_as_u64(op: &Operand) -> u64 { } } -fn extract_literal_u32(op: &Operand) -> u32 { - match op { - Operand::LiteralInt32(v) => *v, - _ => panic!("Unexpected literal u32"), - } -} - fn apply_rewrite_rules(rewrite_rules: &HashMap, blocks: &mut [Block]) { let apply = |inst: &mut Instruction| { if let Some(ref mut id) = &mut inst.result_id { @@ -107,7 +83,7 @@ fn apply_rewrite_rules(rewrite_rules: &HashMap, blocks: &mut [Block] } inst.operands.iter_mut().for_each(|op| { - if let Some(id) = operand_idref_mut(op) { + if let Some(id) = op.id_ref_any_mut() { if let Some(&rewrite) = rewrite_rules.get(id) { *id = rewrite; } @@ -192,23 +168,18 @@ pub fn link( for inst in &output.types_global_values { match inst.class.opcode { Op::TypePointer => { - pointer_to_pointee.insert( - inst.result_id.unwrap(), - operand_idref(&inst.operands[1]).unwrap(), - ); + pointer_to_pointee + .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref()); } Op::TypeInt - if inst.operands[0] == Operand::LiteralInt32(32) - && inst.operands[1] == Operand::LiteralInt32(0) => + if inst.operands[0].unwrap_literal_int32() == 32 + && inst.operands[1].unwrap_literal_int32() == 0 => { assert!(u32.is_none()); u32 = Some(inst.result_id.unwrap()); } Op::Constant if u32.is_some() && inst.result_type == u32 => { - let value = match inst.operands[0] { - Operand::LiteralInt32(value) => value, - _ => panic!(), - }; + let value = inst.operands[0].unwrap_literal_int32(); constants.insert(inst.result_id.unwrap(), value); } _ => {} diff --git a/rspirv-linker/src/mem2reg.rs b/rspirv-linker/src/mem2reg.rs index f93a5b00bf..b62b05d430 100644 --- a/rspirv-linker/src/mem2reg.rs +++ b/rspirv-linker/src/mem2reg.rs @@ -10,7 +10,7 @@ //! wikipedia calls "treat pruning as a dead code elimination problem"). use crate::simple_passes::outgoing_edges; -use crate::{apply_rewrite_rules, id, label_of, operand_idref}; +use crate::{apply_rewrite_rules, id}; use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand}; use rspirv::spirv::{Op, Word}; use std::collections::{hash_map, HashMap, HashSet}; @@ -37,7 +37,10 @@ pub fn compute_preds(blocks: &[Block]) -> Vec> { let mut result = vec![vec![]; blocks.len()]; for (source_idx, source) in blocks.iter().enumerate() { for dest_id in outgoing_edges(source) { - let dest_idx = blocks.iter().position(|b| label_of(b) == dest_id).unwrap(); + let dest_idx = blocks + .iter() + .position(|b| b.label_id().unwrap() == dest_id) + .unwrap(); result[dest_idx].push(source_idx); } } @@ -171,7 +174,7 @@ fn collect_access_chains( indices: { let mut base_indicies = base.indices.clone(); for op in inst.operands.iter().skip(1) { - base_indicies.push(*constants.get(&operand_idref(op).unwrap())?) + base_indicies.push(*constants.get(&op.id_ref_any().unwrap())?) } base_indicies }, @@ -201,7 +204,7 @@ fn collect_access_chains( } } if let Op::AccessChain | Op::InBoundsAccessChain = inst.class.opcode { - if let Some(base) = variables.get(&operand_idref(&inst.operands[0]).unwrap()) { + if let Some(base) = variables.get(&inst.operands[0].id_ref_any().unwrap()) { let info = construct_access_chain_info(pointer_to_pointee, constants, inst, base)?; match variables.entry(inst.result_id.unwrap()) { @@ -224,7 +227,7 @@ fn collect_access_chains( fn has_store(block: &Block, var_map: &HashMap) -> bool { block.instructions.iter().any(|inst| { let ptr = match inst.class.opcode { - Op::Store => operand_idref(&inst.operands[0]).unwrap(), + Op::Store => inst.operands[0].id_ref_any().unwrap(), Op::Variable if inst.operands.len() < 2 => return false, Op::Variable => inst.result_id.unwrap(), _ => return false, @@ -274,7 +277,7 @@ struct Renamer<'a> { impl Renamer<'_> { // Returns the phi definition. fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word { - let from_block_label = label_of(&self.blocks[from_block]); + let from_block_label = self.blocks[from_block].label_id().unwrap(); let phi_defs = &self.phi_defs; let existing_phi = self.blocks[block].instructions.iter_mut().find(|inst| { inst.class.opcode == Op::Phi && phi_defs.contains(&inst.result_id.unwrap()) @@ -325,14 +328,14 @@ impl Renamer<'_> { for inst in &mut self.blocks[block].instructions { if inst.class.opcode == Op::Variable && inst.operands.len() > 1 { let ptr = inst.result_id.unwrap(); - let val = operand_idref(&inst.operands[1]).unwrap(); + let val = inst.operands[1].id_ref_any().unwrap(); if let Some(var_info) = self.var_map.get(&ptr) { assert_eq!(var_info.indices, []); self.stack.push(val); } } else if inst.class.opcode == Op::Store { - let ptr = operand_idref(&inst.operands[0]).unwrap(); - let val = operand_idref(&inst.operands[1]).unwrap(); + let ptr = inst.operands[0].id_ref_any().unwrap(); + let val = inst.operands[1].id_ref_any().unwrap(); if let Some(var_info) = self.var_map.get(&ptr) { if var_info.indices.is_empty() { *inst = Instruction::new(Op::Nop, None, None, vec![]); @@ -353,7 +356,7 @@ impl Renamer<'_> { } } } else if inst.class.opcode == Op::Load { - let ptr = operand_idref(&inst.operands[0]).unwrap(); + let ptr = inst.operands[0].id_ref_any().unwrap(); if let Some(var_info) = self.var_map.get(&ptr) { let loaded_val = inst.result_id.unwrap(); let current_obj = *self.stack.last().unwrap(); @@ -382,7 +385,7 @@ impl Renamer<'_> { let dest_idx = self .blocks .iter() - .position(|b| label_of(b) == dest_id) + .position(|b| b.label_id().unwrap() == dest_id) .unwrap(); self.rename(dest_idx, Some(block)); } @@ -414,7 +417,7 @@ fn remove_old_variables(blocks: &mut [Block], thing: &[(HashMap, u block.instructions.retain(|inst| { !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain) || inst.operands.iter().all(|op| { - operand_idref(op).map_or(true, |id| { + op.id_ref_any().map_or(true, |id| { thing.iter().all(|(var_map, _)| !var_map.contains_key(&id)) }) }) diff --git a/rspirv-linker/src/simple_passes.rs b/rspirv-linker/src/simple_passes.rs index 8788e09b63..43747f6f01 100644 --- a/rspirv-linker/src/simple_passes.rs +++ b/rspirv-linker/src/simple_passes.rs @@ -1,5 +1,4 @@ -use crate::{label_of, operand_idref_mut}; -use rspirv::dr::{Block, Function, Module, Operand}; +use rspirv::dr::{Block, Function, Module}; use rspirv::spirv::{Op, Word}; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -16,7 +15,7 @@ pub fn shift_ids(module: &mut Module, add: u32) { } inst.operands.iter_mut().for_each(|op| { - if let Some(w) = operand_idref_mut(op) { + if let Some(w) = op.id_ref_any_mut() { *w += add } }) @@ -40,7 +39,11 @@ pub fn block_ordering_pass(func: &mut Function) { if !visited.insert(current) { return; } - let current_block = func.blocks.iter().find(|b| label_of(b) == current).unwrap(); + let current_block = func + .blocks + .iter() + .find(|b| b.label_id().unwrap() == current) + .unwrap(); // Reverse the order, so reverse-postorder keeps things tidy for &outgoing in outgoing_edges(current_block).iter().rev() { visit_postorder(func, visited, postorder, outgoing); @@ -51,7 +54,7 @@ pub fn block_ordering_pass(func: &mut Function) { let mut visited = HashSet::new(); let mut postorder = Vec::new(); - let entry_label = label_of(&func.blocks[0]); + let entry_label = func.blocks[0].label_id().unwrap(); visit_postorder(func, &mut visited, &mut postorder, entry_label); let mut old_blocks = replace(&mut func.blocks, Vec::new()); @@ -59,35 +62,29 @@ pub fn block_ordering_pass(func: &mut Function) { for &block in postorder.iter().rev() { let index = old_blocks .iter() - .position(|b| label_of(b) == block) + .position(|b| b.label_id().unwrap() == block) .unwrap(); func.blocks.push(old_blocks.remove(index)); } // Note: if old_blocks isn't empty here, that means there were unreachable blocks that were deleted. - assert_eq!(label_of(&func.blocks[0]), entry_label); + assert_eq!(func.blocks[0].label_id().unwrap(), entry_label); } pub fn outgoing_edges(block: &Block) -> Vec { - fn unwrap_id_ref(operand: &Operand) -> Word { - match *operand { - Operand::IdRef(word) => word, - _ => panic!("Expected Operand::IdRef: {}", operand), - } - } let terminator = block.instructions.last().unwrap(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination match terminator.class.opcode { - Op::Branch => vec![unwrap_id_ref(&terminator.operands[0])], + Op::Branch => vec![terminator.operands[0].unwrap_id_ref()], Op::BranchConditional => vec![ - unwrap_id_ref(&terminator.operands[1]), - unwrap_id_ref(&terminator.operands[2]), + terminator.operands[1].unwrap_id_ref(), + terminator.operands[2].unwrap_id_ref(), ], - Op::Switch => once(unwrap_id_ref(&terminator.operands[1])) + Op::Switch => once(terminator.operands[1].unwrap_id_ref()) .chain( terminator.operands[3..] .iter() .step_by(2) - .map(unwrap_id_ref), + .map(|op| op.unwrap_id_ref()), ) .collect(), Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(), @@ -113,7 +110,7 @@ pub fn compact_ids(module: &mut Module) -> u32 { } inst.operands.iter_mut().for_each(|op| { - if let Some(w) = operand_idref_mut(op) { + if let Some(w) = op.id_ref_any_mut() { *w = insert(*w); } }) diff --git a/rspirv-linker/src/ty.rs b/rspirv-linker/src/ty.rs index 494944d971..b3dda3dba8 100644 --- a/rspirv-linker/src/ty.rs +++ b/rspirv-linker/src/ty.rs @@ -1,5 +1,5 @@ -use crate::{extract_literal_int_as_u64, extract_literal_u32, DefAnalyzer}; -use rspirv::dr::{Instruction, Operand}; +use crate::{extract_literal_int_as_u64, DefAnalyzer}; +use rspirv::dr::Instruction; use rspirv::spirv::{AccessQualifier, Dim, ImageFormat, Op, StorageClass}; #[derive(PartialEq, Debug)] @@ -33,32 +33,17 @@ fn trans_scalar_type(inst: &Instruction) -> Option { Op::TypeNamedBarrier => ScalarType::NamedBarrier, Op::TypeSampler => ScalarType::Sampler, Op::TypeForwardPointer => ScalarType::ForwardPointer { - storage_class: match inst.operands[0] { - Operand::StorageClass(s) => s, - _ => panic!("Unexpected operand while parsing type"), - }, + storage_class: inst.operands[0].unwrap_storage_class(), }, Op::TypeInt => ScalarType::Int { - width: match inst.operands[0] { - Operand::LiteralInt32(w) => w, - _ => panic!("Unexpected operand while parsing type"), - }, - signed: match inst.operands[1] { - Operand::LiteralInt32(s) => s != 0, - _ => panic!("Unexpected operand while parsing type"), - }, + width: inst.operands[0].unwrap_literal_int32(), + signed: inst.operands[1].unwrap_literal_int32() != 0, }, Op::TypeFloat => ScalarType::Float { - width: match inst.operands[0] { - Operand::LiteralInt32(w) => w, - _ => panic!("Unexpected operand while parsing type"), - }, + width: inst.operands[0].unwrap_literal_int32(), }, Op::TypeOpaque => ScalarType::Opaque { - name: match &inst.operands[0] { - Operand::LiteralString(s) => s.clone(), - _ => panic!("Unexpected operand while parsing type"), - }, + name: inst.operands[0].unwrap_literal_string().to_string(), }, _ => return None, }) @@ -139,10 +124,7 @@ pub(crate) fn trans_aggregate_type(def: &DefAnalyzer, inst: &Instruction) -> Opt } } Op::TypePointer => AggregateType::Pointer { - storage_class: match inst.operands[0] { - Operand::StorageClass(s) => s, - _ => panic!("Unexpected operand while parsing type"), - }, + storage_class: inst.operands[0].unwrap_storage_class(), ty: Box::new( trans_aggregate_type(def, &def.op_def(&inst.operands[1])) .expect("Expect base type for OpTypePointer"), @@ -186,26 +168,13 @@ pub(crate) fn trans_aggregate_type(def: &DefAnalyzer, inst: &Instruction) -> Opt trans_aggregate_type(def, &def.op_def(&inst.operands[0])) .expect("Expect base type for OpTypeImage"), ), - dim: match inst.operands[1] { - Operand::Dim(d) => d, - _ => panic!("Invalid dim"), - }, - depth: extract_literal_u32(&inst.operands[2]), - arrayed: extract_literal_u32(&inst.operands[3]), - multi_sampled: extract_literal_u32(&inst.operands[4]), - sampled: extract_literal_u32(&inst.operands[5]), - format: match inst.operands[6] { - Operand::ImageFormat(f) => f, - _ => panic!("Invalid image format"), - }, - access: inst - .operands - .get(7) - .map(|op| match op { - Operand::AccessQualifier(a) => Some(*a), - _ => None, - }) - .flatten(), + dim: inst.operands[1].unwrap_dim(), + depth: inst.operands[2].unwrap_literal_int32(), + arrayed: inst.operands[3].unwrap_literal_int32(), + multi_sampled: inst.operands[4].unwrap_literal_int32(), + sampled: inst.operands[5].unwrap_literal_int32(), + format: inst.operands[6].unwrap_image_format(), + access: inst.operands.get(7).map(|op| op.unwrap_access_qualifier()), }, _ => { if let Some(ty) = trans_scalar_type(inst) { diff --git a/rspirv-linker/src/zombies.rs b/rspirv-linker/src/zombies.rs index fdfafbbeec..d39faef45a 100644 --- a/rspirv-linker/src/zombies.rs +++ b/rspirv-linker/src/zombies.rs @@ -1,7 +1,6 @@ //! See documentation on CodegenCx::zombie for a description of the zombie system. -use crate::operand_idref; -use rspirv::dr::{Instruction, Module, Operand}; +use rspirv::dr::{Instruction, Module}; use rspirv::spirv::{Decoration, Op, Word}; use std::collections::{hash_map, HashMap}; use std::env; @@ -13,15 +12,11 @@ fn collect_zombies(module: &Module) -> Vec<(Word, String)> { .filter_map(|inst| { // TODO: Temp hack. We hijack UserTypeGOOGLE right now, since the compiler never emits this. if inst.class.opcode == Op::DecorateString - && inst.operands[1] == Operand::Decoration(Decoration::UserTypeGOOGLE) + && inst.operands[1].unwrap_decoration() == Decoration::UserTypeGOOGLE { - if let (&Operand::IdRef(id), Operand::LiteralString(reason)) = - (&inst.operands[0], &inst.operands[2]) - { - return Some((id, reason.to_string())); - } else { - panic!("Invalid OpDecorateString") - } + let id = inst.operands[0].unwrap_id_ref(); + let reason = inst.operands[2].unwrap_literal_string(); + return Some((id, reason.to_string())); } None }) @@ -31,7 +26,7 @@ fn collect_zombies(module: &Module) -> Vec<(Word, String)> { fn remove_zombie_annotations(module: &mut Module) { module.annotations.retain(|inst| { inst.class.opcode != Op::DecorateString - || inst.operands[1] != Operand::Decoration(Decoration::UserTypeGOOGLE) + || inst.operands[1].unwrap_decoration() != Decoration::UserTypeGOOGLE }) } @@ -43,7 +38,7 @@ fn contains_zombie<'a>(inst: &Instruction, zombie: &HashMap) -> O } inst.operands .iter() - .find_map(|op| operand_idref(op).and_then(|w| zombie.get(&w).copied())) + .find_map(|op| op.id_ref_any().and_then(|w| zombie.get(&w).copied())) } fn is_zombie<'a>(inst: &Instruction, zombie: &HashMap) -> Option<&'a str> { @@ -144,13 +139,12 @@ pub fn remove_zombies(module: &mut Module) { if let Some(reason) = is_zombie(f.def.as_ref().unwrap(), &zombies) { let name_id = f.def.as_ref().unwrap().result_id.unwrap(); let name = module.debugs.iter().find(|inst| { - inst.class.opcode == Op::Name && inst.operands[0] == Operand::IdRef(name_id) + inst.class.opcode == Op::Name && inst.operands[0].unwrap_id_ref() == name_id }); let name = match name { - Some(Instruction { ref operands, .. }) => match operands as &[Operand] { - [_, Operand::LiteralString(name)] => name.clone(), - _ => panic!(), - }, + Some(Instruction { ref operands, .. }) => { + operands[1].unwrap_literal_string().to_string() + } _ => format!("{}", name_id), }; println!("Function removed {:?} because {:?}", name, reason) diff --git a/rustc_codegen_spirv/src/abi.rs b/rustc_codegen_spirv/src/abi.rs index 8f5c85dc61..68c1c44db1 100644 --- a/rustc_codegen_spirv/src/abi.rs +++ b/rustc_codegen_spirv/src/abi.rs @@ -4,6 +4,7 @@ use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; use crate::symbols::{parse_attr, SpirvAttribute}; use rspirv::spirv::{StorageClass, Word}; +use rustc_middle::bug; use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout}; use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind, TypeAndMut}; use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind}; @@ -68,7 +69,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> { // Warning: storage_class must match the one called with begin() match self.map.borrow_mut().entry((pointee, storage_class)) { // We should have hit begin() on this type already, which always inserts an entry. - Entry::Vacant(_) => panic!("RecursivePointeeCache::end should always have entry"), + Entry::Vacant(_) => bug!("RecursivePointeeCache::end should always have entry"), Entry::Occupied(mut entry) => match *entry.get() { // State: There have been no recursive references to this type while defining it, and so no // OpTypeForwardPointer has been emitted. This is the most common case. @@ -92,7 +93,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> { .def_with_id(cx, id) } PointeeDefState::Defined(_) => { - panic!("RecursivePointeeCache::end defined pointer twice") + bug!("RecursivePointeeCache::end defined pointer twice") } }, } @@ -350,7 +351,7 @@ pub fn scalar_pair_element_backend_type<'tcx>( ) -> Word { let scalar = match &ty.layout.abi { Abi::ScalarPair(a, b) => [a, b][index], - other => panic!("scalar_pair_element_backend_type invalid abi: {:?}", other), + other => bug!("scalar_pair_element_backend_type invalid abi: {:?}", other), }; trans_scalar(cx, ty, scalar, Some(index), is_immediate) } @@ -445,10 +446,10 @@ fn dig_scalar_pointee<'tcx>( TyKind::Tuple(_) | TyKind::Adt(..) | TyKind::Closure(..) => { dig_scalar_pointee_adt(cx, ty, index) } - ref kind => panic!( + ref kind => cx.tcx.sess.fatal(&format!( "TODO: Unimplemented Primitive::Pointer TyKind index={:?} ({:#?}):\n{:#?}", index, kind, ty - ), + )), } } @@ -471,10 +472,10 @@ fn dig_scalar_pointee_adt<'tcx>( .. } => { match *tag_encoding { - TagEncoding::Direct => panic!( + TagEncoding::Direct => cx.tcx.sess.fatal(&format!( "dig_scalar_pointee_adt Variants::Multiple TagEncoding::Direct makes no sense: {:#?}", ty - ), + )), TagEncoding::Niche { dataful_variant, .. } => { // This *should* be something like Option<&T>: a very simple enum. // TODO: This might not be, if it's a scalar pair? @@ -486,7 +487,7 @@ fn dig_scalar_pointee_adt<'tcx>( let field_ty = adt.variants[dataful_variant].fields[0].ty(cx.tcx, substs); dig_scalar_pointee(cx, cx.layout_of(field_ty), index) } else { - panic!("Variants::Multiple not TyKind::Adt: {:#?}", ty) + bug!("Variants::Multiple not TyKind::Adt: {:#?}", ty) } }, } @@ -503,14 +504,17 @@ fn dig_scalar_pointee_adt<'tcx>( 1 => dig_scalar_pointee(cx, fields[0], Some(index)), // This case right here is the cause of the comment handling TyKind::Ref. 2 => dig_scalar_pointee(cx, fields[index], None), - other => panic!( + other => cx.tcx.sess.fatal(&format!( "Unable to dig scalar pair pointer type: fields length {}", other - ), + )), }, None => match fields.len() { 1 => dig_scalar_pointee(cx, fields[0], None), - other => panic!("Unable to dig scalar pointer type: fields length {}", other), + other => cx.tcx.sess.fatal(&format!( + "Unable to dig scalar pointer type: fields length {}", + other + )), }, } } @@ -518,10 +522,10 @@ fn dig_scalar_pointee_adt<'tcx>( match (storage_class, result) { (storage_class, (None, result)) => (storage_class, result), (None, (storage_class, result)) => (storage_class, result), - (Some(one), (Some(two), _)) => panic!( + (Some(one), (Some(two), _)) => cx.tcx.sess.fatal(&format!( "Double-applied storage class ({:?} and {:?}) on type {}", one, two, ty.ty - ), + )), } } @@ -541,10 +545,10 @@ fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Optio fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word { match ty.fields { - FieldsShape::Primitive => panic!( + FieldsShape::Primitive => cx.tcx.sess.fatal(&format!( "FieldsShape::Primitive not supported yet in trans_type: {:?}", ty - ), + )), FieldsShape::Union(_) => { assert_ne!(ty.size.bytes(), 0, "{:#?}", ty); assert!(!ty.is_unsized(), "{:#?}", ty); @@ -659,12 +663,12 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word { } else { if let TyKind::Adt(_, _) = ty.ty.kind() { } else { - panic!("Variants::Multiple not supported for non-TyKind::Adt"); + bug!("Variants::Multiple not TyKind::Adt"); } if i == 0 { field_names.push("discriminant".to_string()); } else { - panic!("Variants::Multiple has multiple fields") + cx.tcx.sess.fatal("Variants::Multiple has multiple fields") } }; } diff --git a/rustc_codegen_spirv/src/builder/builder_methods.rs b/rustc_codegen_spirv/src/builder/builder_methods.rs index 2be0d3c370..2cb7a970a1 100644 --- a/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -118,14 +118,17 @@ fn memset_dynamic_scalar( impl<'a, 'tcx> Builder<'a, 'tcx> { fn memset_const_pattern(&self, ty: &SpirvType, fill_byte: u8) -> Word { match *ty { - SpirvType::Void => panic!("memset invalid on void pattern"), - SpirvType::Bool => panic!("memset invalid on bool pattern"), + SpirvType::Void => self.fatal("memset invalid on void pattern"), + SpirvType::Bool => self.fatal("memset invalid on bool pattern"), SpirvType::Integer(width, _signedness) => match width { 8 => self.constant_u8(fill_byte).def, 16 => self.constant_u16(memset_fill_u16(fill_byte)).def, 32 => self.constant_u32(memset_fill_u32(fill_byte)).def, 64 => self.constant_u64(memset_fill_u64(fill_byte)).def, - _ => panic!("memset on integer width {} not implemented yet", width), + _ => self.fatal(&format!( + "memset on integer width {} not implemented yet", + width + )), }, SpirvType::Float(width) => match width { 32 => { @@ -136,10 +139,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.constant_f64(f64::from_bits(memset_fill_u64(fill_byte))) .def } - _ => panic!("memset on float width {} not implemented yet", width), + _ => self.fatal(&format!( + "memset on float width {} not implemented yet", + width + )), }, - SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"), - SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"), + SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"), + SpirvType::Opaque { .. } => self.fatal("memset on opaque type is invalid"), SpirvType::Vector { element, count } => { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); self.constant_composite(ty.clone().def(self), vec![elem_pat; count as usize]) @@ -152,31 +158,37 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .def } SpirvType::RuntimeArray { .. } => { - panic!("memset on runtime arrays not implemented yet") + self.fatal("memset on runtime arrays not implemented yet") } - SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"), - SpirvType::Function { .. } => panic!("memset on functions not implemented yet"), + SpirvType::Pointer { .. } => self.fatal("memset on pointers not implemented yet"), + SpirvType::Function { .. } => self.fatal("memset on functions not implemented yet"), } } fn memset_dynamic_pattern(&self, ty: &SpirvType, fill_var: Word) -> Word { match *ty { - SpirvType::Void => panic!("memset invalid on void pattern"), - SpirvType::Bool => panic!("memset invalid on bool pattern"), + SpirvType::Void => self.fatal("memset invalid on void pattern"), + SpirvType::Bool => self.fatal("memset invalid on bool pattern"), SpirvType::Integer(width, _signedness) => match width { 8 => fill_var, 16 => memset_dynamic_scalar(self, fill_var, 2, false), 32 => memset_dynamic_scalar(self, fill_var, 4, false), 64 => memset_dynamic_scalar(self, fill_var, 8, false), - _ => panic!("memset on integer width {} not implemented yet", width), + _ => self.fatal(&format!( + "memset on integer width {} not implemented yet", + width + )), }, SpirvType::Float(width) => match width { 32 => memset_dynamic_scalar(self, fill_var, 4, true), 64 => memset_dynamic_scalar(self, fill_var, 8, true), - _ => panic!("memset on float width {} not implemented yet", width), + _ => self.fatal(&format!( + "memset on float width {} not implemented yet", + width + )), }, - SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"), - SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"), + SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"), + SpirvType::Opaque { .. } => self.fatal("memset on opaque type is invalid"), SpirvType::Array { element, count } => { let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var); let count = self.builder.lookup_const_u64(count).unwrap() as usize; @@ -199,10 +211,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .unwrap() } SpirvType::RuntimeArray { .. } => { - panic!("memset on runtime arrays not implemented yet") + self.fatal("memset on runtime arrays not implemented yet") } - SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"), - SpirvType::Function { .. } => panic!("memset on functions not implemented yet"), + SpirvType::Pointer { .. } => self.fatal("memset on pointers not implemented yet"), + SpirvType::Function { .. } => self.fatal("memset on functions not implemented yet"), } } @@ -289,7 +301,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .memory_model .as_ref() .map_or(false, |inst| { - inst.operands[0] == Operand::AddressingModel(AddressingModel::Logical) + inst.operands[0].unwrap_addressing_model() == AddressingModel::Logical }); if is_logical { self.zombie(def, "OpBitcast on ptr without AddressingModel != Logical") @@ -346,10 +358,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let selected_function = &emit.module_ref().functions[selected_function]; let def_inst = selected_function.def.as_ref().unwrap(); let def = def_inst.result_id.unwrap(); - let ty = match def_inst.operands[1] { - Operand::IdRef(ty) => ty, - ref other => panic!("Invalid operand to function inst: {}", other), - }; + let ty = def_inst.operands[1].unwrap_id_ref(); def.with_type(ty) }; self.cursor = cursor; @@ -442,57 +451,74 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // TODO: Remove once structurizer is done. self.zombie(else_llbb, "OpSwitch before structurizer is done"); } + + fn construct_8(self_: &Builder, signed: bool, v: u128) -> Operand { + if v > u8::MAX as u128 { + self_.fatal(&format!( + "Switches to values above u8::MAX not supported: {:?}", + v + )) + } else if signed { + // this cast chain can probably be collapsed, but, whatever, be safe + Operand::LiteralInt32(v as u8 as i8 as i32 as u32) + } else { + Operand::LiteralInt32(v as u8 as u32) + } + } + fn construct_16(self_: &Builder, signed: bool, v: u128) -> Operand { + if v > u16::MAX as u128 { + self_.fatal(&format!( + "Switches to values above u16::MAX not supported: {:?}", + v + )) + } else if signed { + Operand::LiteralInt32(v as u16 as i16 as i32 as u32) + } else { + Operand::LiteralInt32(v as u16 as u32) + } + } + fn construct_32(self_: &Builder, _signed: bool, v: u128) -> Operand { + if v > u32::MAX as u128 { + self_.fatal(&format!( + "Switches to values above u32::MAX not supported: {:?}", + v + )) + } else { + Operand::LiteralInt32(v as u32) + } + } + fn construct_64(self_: &Builder, _signed: bool, v: u128) -> Operand { + if v > u64::MAX as u128 { + self_.fatal(&format!( + "Switches to values above u64::MAX not supported: {:?}", + v + )) + } else { + Operand::LiteralInt64(v as u64) + } + } // pass in signed into the closure to be able to unify closure types let (signed, construct_case) = match self.lookup_type(v.ty) { SpirvType::Integer(width, signed) => { let construct_case = match width { - 8 => |signed, v| { - if v > u8::MAX as u128 { - panic!("Switches to values above u8::MAX not supported: {:?}", v) - } else if signed { - // this cast chain can probably be collapsed, but, whatever, be safe - Operand::LiteralInt32(v as u8 as i8 as i32 as u32) - } else { - Operand::LiteralInt32(v as u8 as u32) - } - }, - 16 => |signed, v| { - if v > u16::MAX as u128 { - panic!("Switches to values above u16::MAX not supported: {:?}", v) - } else if signed { - Operand::LiteralInt32(v as u16 as i16 as i32 as u32) - } else { - Operand::LiteralInt32(v as u16 as u32) - } - }, - 32 => |_signed, v| { - if v > u32::MAX as u128 { - panic!("Switches to values above u32::MAX not supported: {:?}", v) - } else { - Operand::LiteralInt32(v as u32) - } - }, - 64 => |_signed, v| { - if v > u64::MAX as u128 { - panic!("Switches to values above u64::MAX not supported: {:?}", v) - } else { - Operand::LiteralInt64(v as u64) - } - }, - other => panic!( - "switch selector cannot have width {} (only 32 and 64 bits allowed)", + 8 => construct_8, + 16 => construct_16, + 32 => construct_32, + 64 => construct_64, + other => self.fatal(&format!( + "switch selector cannot have width {} (only 8, 16, 32, and 64 bits allowed)", other - ), + )), }; (signed, construct_case) } - other => panic!( + other => self.fatal(&format!( "switch selector cannot have non-integer type {}", other.debug(v.ty, self) - ), + )), }; let cases = cases - .map(|(i, b)| (construct_case(signed, i), b)) + .map(|(i, b)| (construct_case(self, signed, i), b)) .collect::>(); self.emit().switch(v.def, else_llbb, cases).unwrap() } @@ -552,7 +578,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { match self.lookup_type(ty) { SpirvType::Integer(_, _) => self.emit().bitwise_and(ty, None, lhs.def, rhs.def), SpirvType::Bool => self.emit().logical_and(ty, None, lhs.def, rhs.def), - o => panic!("and() not implemented for type {}", o.debug(ty, self)), + o => self.fatal(&format!( + "and() not implemented for type {}", + o.debug(ty, self) + )), } .unwrap() .with_type(ty) @@ -563,7 +592,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { match self.lookup_type(ty) { SpirvType::Integer(_, _) => self.emit().bitwise_or(ty, None, lhs.def, rhs.def), SpirvType::Bool => self.emit().logical_or(ty, None, lhs.def, rhs.def), - o => panic!("or() not implemented for type {}", o.debug(ty, self)), + o => self.fatal(&format!( + "or() not implemented for type {}", + o.debug(ty, self) + )), } .unwrap() .with_type(ty) @@ -574,7 +606,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { match self.lookup_type(ty) { SpirvType::Integer(_, _) => self.emit().bitwise_xor(ty, None, lhs.def, rhs.def), SpirvType::Bool => self.emit().logical_not_equal(ty, None, lhs.def, rhs.def), - o => panic!("xor() not implemented for type {}", o.debug(ty, self)), + o => self.fatal(&format!( + "xor() not implemented for type {}", + o.debug(ty, self) + )), } .unwrap() .with_type(ty) @@ -588,7 +623,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.emit() .logical_not_equal(val.ty, None, val.def, true_.def) } - o => panic!("not() not implemented for type {}", o.debug(val.ty, self)), + o => self.fatal(&format!( + "not() not implemented for type {}", + o.debug(val.ty, self) + )), } .unwrap() .with_type(val.ty) @@ -678,7 +716,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn array_alloca(&mut self, _ty: Self::Type, _len: Self::Value, _align: Align) -> Self::Value { - panic!("TODO: array_alloca not supported yet") + self.fatal("TODO: array_alloca not supported yet") } fn load(&mut self, ptr: Self::Value, _align: Align) -> Self::Value { @@ -687,7 +725,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { storage_class: _, pointee, } => pointee, - ty => panic!("load called on variable that wasn't a pointer: {:?}", ty), + ty => self.fatal(&format!( + "load called on variable that wasn't a pointer: {:?}", + ty + )), }; self.emit() .load(ty, None, ptr.def, None, empty()) @@ -706,10 +747,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { storage_class: _, pointee, } => pointee, - ty => panic!( + ty => self.fatal(&format!( "atomic_load called on variable that wasn't a pointer: {:?}", ty - ), + )), }; // TODO: Default to device scope let memory = self.constant_u32(Scope::Device as u32); @@ -833,7 +874,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { storage_class: _, pointee, } => pointee, - ty => panic!("store called on variable that wasn't a pointer: {:?}", ty), + ty => self.fatal(&format!( + "store called on variable that wasn't a pointer: {:?}", + ty + )), }; assert_ty_eq!(self, ptr_elem_ty, val.ty); self.emit().store(ptr.def, val.def, None, empty()).unwrap(); @@ -862,10 +906,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { storage_class: _, pointee, } => pointee, - ty => panic!( + ty => self.fatal(&format!( "atomic_store called on variable that wasn't a pointer: {:?}", ty - ), + )), }; assert_ty_eq!(self, ptr_elem_ty, val.ty); // TODO: Default to device scope @@ -895,12 +939,15 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } | SpirvType::Vector { element, .. } => (storage_class, element), - other => panic!( + other => self.fatal(&format!( "struct_gep not on struct, array, or vector type: {:?}, index {}", other, idx - ), + )), }, - other => panic!("struct_gep not on pointer type: {:?}, index {}", other, idx), + other => self.fatal(&format!( + "struct_gep not on pointer type: {:?}, index {}", + other, idx + )), }; let result_type = SpirvType::Pointer { storage_class, @@ -910,7 +957,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // Important! LLVM, and therefore intel-compute-runtime, require the `getelementptr` instruction (and therefore // OpAccessChain) on structs to be a constant i32. Not i64! i32. if idx > u32::MAX as u64 { - panic!("struct_gep bigger than u32::MAX"); + self.fatal("struct_gep bigger than u32::MAX"); } let index_const = self.constant_u32(idx as u32).def; self.emit() @@ -1011,7 +1058,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn ptrtoint(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { match self.lookup_type(val.ty) { SpirvType::Pointer { .. } => (), - other => panic!("ptrtoint called on non-pointer source type: {:?}", other), + other => self.fatal(&format!( + "ptrtoint called on non-pointer source type: {:?}", + other + )), } if val.ty == dest_ty { val @@ -1029,7 +1079,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn inttoptr(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { match self.lookup_type(dest_ty) { SpirvType::Pointer { .. } => (), - other => panic!("inttoptr called on non-pointer dest type: {:?}", other), + other => self.fatal(&format!( + "inttoptr called on non-pointer dest type: {:?}", + other + )), } if val.ty == dest_ty { val @@ -1107,21 +1160,27 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .unwrap() .with_type(dest_ty) } - (val_ty, dest_ty_spv) => panic!( + (val_ty, dest_ty_spv) => self.fatal(&format!( "TODO: intcast not implemented yet: val={:?} val.ty={:?} dest_ty={:?} is_signed={}", val, val_ty, dest_ty_spv, is_signed - ), + )), } } fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { let val_pointee = match self.lookup_type(val.ty) { SpirvType::Pointer { pointee, .. } => pointee, - other => panic!("pointercast called on non-pointer source type: {:?}", other), + other => self.fatal(&format!( + "pointercast called on non-pointer source type: {:?}", + other + )), }; let dest_pointee = match self.lookup_type(dest_ty) { SpirvType::Pointer { pointee, .. } => pointee, - other => panic!("pointercast called on non-pointer dest type: {:?}", other), + other => self.fatal(&format!( + "pointercast called on non-pointer dest type: {:?}", + other + )), }; if val.ty == dest_ty { val @@ -1226,10 +1285,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.zombie_convert_ptr_to_u(rhs); self.emit().u_less_than_equal(b, None, lhs, rhs) } - IntSGT => panic!("TODO: pointer operator IntSGT not implemented yet"), - IntSGE => panic!("TODO: pointer operator IntSGE not implemented yet"), - IntSLT => panic!("TODO: pointer operator IntSLT not implemented yet"), - IntSLE => panic!("TODO: pointer operator IntSLE not implemented yet"), + IntSGT => self.fatal("TODO: pointer operator IntSGT not implemented yet"), + IntSGE => self.fatal("TODO: pointer operator IntSGE not implemented yet"), + IntSLT => self.fatal("TODO: pointer operator IntSLT not implemented yet"), + IntSLE => self.fatal("TODO: pointer operator IntSLE not implemented yet"), }, SpirvType::Bool => match op { IntEQ => self.emit().logical_equal(b, None, lhs.def, rhs.def), @@ -1271,15 +1330,15 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .unwrap(); self.emit().logical_or(b, None, lhs, rhs.def) } - IntSGT => panic!("TODO: boolean operator IntSGT not implemented yet"), - IntSGE => panic!("TODO: boolean operator IntSGE not implemented yet"), - IntSLT => panic!("TODO: boolean operator IntSLT not implemented yet"), - IntSLE => panic!("TODO: boolean operator IntSLE not implemented yet"), + IntSGT => self.fatal("TODO: boolean operator IntSGT not implemented yet"), + IntSGE => self.fatal("TODO: boolean operator IntSGE not implemented yet"), + IntSLT => self.fatal("TODO: boolean operator IntSLT not implemented yet"), + IntSLE => self.fatal("TODO: boolean operator IntSLE not implemented yet"), }, - other => panic!( + other => self.fatal(&format!( "Int comparison not implemented on {}", other.debug(lhs.ty, self) - ), + )), } .unwrap() .with_type(b) @@ -1361,10 +1420,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { ) { let elem_ty = match self.lookup_type(ptr.ty) { SpirvType::Pointer { pointee, .. } => pointee, - _ => panic!( + _ => self.fatal(&format!( "memset called on non-pointer type: {}", self.debug_type(ptr.ty) - ), + )), }; let elem_ty_spv = self.lookup_type(elem_ty); let pat = match self.builder.lookup_const_u64(fill_byte) { @@ -1399,7 +1458,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn extract_element(&mut self, vec: Self::Value, idx: Self::Value) -> Self::Value { let result_type = match self.lookup_type(vec.ty) { SpirvType::Vector { element, .. } => element, - other => panic!("extract_element not implemented on type {:?}", other), + other => self.fatal(&format!( + "extract_element not implemented on type {:?}", + other + )), }; match self.builder.lookup_const_u64(idx) { Some(const_index) => self.emit().composite_extract( @@ -1435,7 +1497,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value { let result_type = match self.lookup_type(agg_val.ty) { SpirvType::Adt { field_types, .. } => field_types[idx as usize], - other => panic!("extract_value not implemented on type {:?}", other), + other => self.fatal(&format!( + "extract_value not implemented on type {:?}", + other + )), }; self.emit() .composite_extract(result_type, None, agg_val.def, [idx as u32].iter().cloned()) @@ -1448,7 +1513,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvType::Adt { field_types, .. } => { assert_ty_eq!(self, field_types[idx as usize], elt.ty) } - other => panic!("insert_value not implemented on type {:?}", other), + other => self.fatal(&format!("insert_value not implemented on type {:?}", other)), }; self.emit() .composite_insert( @@ -1530,10 +1595,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { storage_class: _, pointee, } => pointee, - ty => panic!( + ty => self.fatal(&format!( "atomic_cmpxchg called on variable that wasn't a pointer: {:?}", ty - ), + )), }; assert_ty_eq!(self, dst_pointee_ty, cmp.ty); assert_ty_eq!(self, dst_pointee_ty, src.ty); @@ -1570,10 +1635,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { storage_class: _, pointee, } => pointee, - ty => panic!( + ty => self.fatal(&format!( "atomic_rmw called on variable that wasn't a pointer: {:?}", ty - ), + )), }; assert_ty_eq!(self, dst_pointee_ty, src.ty); self.validate_atomic(dst_pointee_ty, dst.def); @@ -1587,7 +1652,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { AtomicAdd => emit.atomic_i_add(src.ty, None, dst.def, memory, semantics, src.def), AtomicSub => emit.atomic_i_sub(src.ty, None, dst.def, memory, semantics, src.def), AtomicAnd => emit.atomic_and(src.ty, None, dst.def, memory, semantics, src.def), - AtomicNand => panic!("atomic nand is not supported"), + AtomicNand => self.fatal("atomic nand is not supported"), AtomicOr => emit.atomic_or(src.ty, None, dst.def, memory, semantics, src.def), AtomicXor => emit.atomic_xor(src.ty, None, dst.def, memory, semantics, src.def), AtomicMax => emit.atomic_s_max(src.ty, None, dst.def, memory, semantics, src.def), @@ -1638,7 +1703,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { funclet: Option<&Self::Funclet>, ) -> Self::Value { if funclet.is_some() { - panic!("TODO: Funclets are not supported"); + self.fatal("TODO: Funclets are not supported"); } // dereference pointers let (result_type, argument_types) = loop { @@ -1659,7 +1724,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return_type, arguments, } => break (return_type, arguments), - ty => panic!("Calling non-function type: {:?}", ty), + ty => self.fatal(&format!("Calling non-function type: {:?}", ty)), } }; for (argument, argument_type) in args.iter().zip(argument_types) { diff --git a/rustc_codegen_spirv/src/builder/intrinsics.rs b/rustc_codegen_spirv/src/builder/intrinsics.rs index 146fd445d7..05190401ef 100644 --- a/rustc_codegen_spirv/src/builder/intrinsics.rs +++ b/rustc_codegen_spirv/src/builder/intrinsics.rs @@ -6,6 +6,7 @@ use rspirv::spirv::{CLOp, GLOp}; use rustc_codegen_ssa::mir::operand::OperandRef; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods, IntrinsicCallMethods}; +use rustc_middle::bug; use rustc_middle::ty::{FnDef, Instance, ParamEnv, Ty, TyKind}; use rustc_span::source_map::Span; use rustc_span::sym; @@ -41,7 +42,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { let (def_id, substs) = match *callee_ty.kind() { FnDef(def_id, substs) => (def_id, substs), - _ => panic!("expected fn item type, found {}", callee_ty), + _ => bug!("expected fn item type, found {}", callee_ty), }; let sig = callee_ty.fn_sig(self.tcx); @@ -91,7 +92,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { self.add(args[0].immediate(), args[1].immediate()) } TyKind::Float(_) => self.fadd(args[0].immediate(), args[1].immediate()), - other => panic!("Unimplemented intrinsic type: {:#?}", other), + other => self.fatal(&format!("Unimplemented intrinsic type: {:#?}", other)), } } sym::saturating_sub => { @@ -101,7 +102,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { self.sub(args[0].immediate(), args[1].immediate()) } TyKind::Float(_) => self.fsub(args[0].immediate(), args[1].immediate()), - other => panic!("Unimplemented intrinsic type: {:#?}", other), + other => self.fatal(&format!("Unimplemented intrinsic type: {:#?}", other)), } } @@ -227,7 +228,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { if self.kernel_mode { self.cl_op(CLOp::copysign, [args[0].immediate(), args[1].immediate()]) } else { - panic!("TODO: Shader copysign not supported yet") + self.fatal("TODO: Shader copysign not supported yet") } } sym::floorf32 | sym::floorf64 => { @@ -381,16 +382,16 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { let res3 = self.or(res3, res4); self.or(res1, res3) } - other => panic!("bswap not implemented for int width {}", other), + other => self.fatal(&format!("bswap not implemented for int width {}", other)), } } - _ => panic!("TODO: Unknown intrinsic '{}'", name), + _ => self.fatal(&format!("TODO: Unknown intrinsic '{}'", name)), }; if !fn_abi.ret.is_ignore() { if let PassMode::Cast(_ty) = fn_abi.ret.mode { - panic!("TODO: PassMode::Cast not implemented yet in intrinsics"); + self.fatal("TODO: PassMode::Cast not implemented yet in intrinsics"); } else { OperandRef::from_immediate_or_packed_pair(self, value, result.layout) .val diff --git a/rustc_codegen_spirv/src/builder/mod.rs b/rustc_codegen_spirv/src/builder/mod.rs index b82120fde3..b32ea8d43a 100644 --- a/rustc_codegen_spirv/src/builder/mod.rs +++ b/rustc_codegen_spirv/src/builder/mod.rs @@ -64,6 +64,32 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } + /* + pub fn struct_err(&self, msg: &str) -> DiagnosticBuilder<'_> { + if let Some(current_span) = *self.current_span.borrow() { + self.tcx.sess.struct_span_err(current_span, msg) + } else { + self.tcx.sess.struct_err(msg) + } + } + + pub fn err(&self, msg: &str) { + if let Some(current_span) = *self.current_span.borrow() { + self.tcx.sess.span_err(current_span, msg) + } else { + self.tcx.sess.err(msg) + } + } + */ + + pub fn fatal(&self, msg: &str) -> ! { + if let Some(current_span) = *self.current_span.borrow() { + self.tcx.sess.span_fatal(current_span, msg) + } else { + self.tcx.sess.fatal(msg) + } + } + pub fn gep_help( &self, ptr: SpirvValue, @@ -80,16 +106,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { storage_class, pointee, } => (storage_class, pointee), - other_type => panic!("GEP first deref not implemented for type {:?}", other_type), + other_type => self.fatal(&format!( + "GEP first deref not implemented for type {:?}", + other_type + )), }; for index in indices.iter().cloned().skip(1) { result_indices.push(index.def); result_pointee_type = match self.lookup_type(result_pointee_type) { SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => element, - _ => panic!( + _ => self.fatal(&format!( "GEP not implemented for type {}", self.debug_type(result_pointee_type) - ), + )), }; } let result_type = SpirvType::Pointer { @@ -144,10 +173,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { fn rotate(&mut self, value: SpirvValue, shift: SpirvValue, is_left: bool) -> SpirvValue { let width = match self.lookup_type(shift.ty) { SpirvType::Integer(width, _) => width, - other => panic!( + other => self.fatal(&format!( "Cannot rotate non-integer type: {}", other.debug(shift.ty, self) - ), + )), }; let int_size = self.constant_int(shift.ty, width as u64); let mask = self.constant_int(shift.ty, (width - 1) as u64); @@ -283,7 +312,7 @@ impl<'a, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'tcx> { if arg_abi.is_sized_indirect() { OperandValue::Ref(val, None, arg_abi.layout.align.abi).store(self, dst); } else if arg_abi.is_unsized_indirect() { - panic!("unsized `ArgAbi` must be handled through `store_fn_arg`"); + self.fatal("unsized `ArgAbi` must be handled through `store_fn_arg`"); } else if let PassMode::Cast(cast) = arg_abi.mode { let cast_ty = cast.spirv_type(self); let cast_ptr_ty = SpirvType::Pointer { diff --git a/rustc_codegen_spirv/src/builder_spirv.rs b/rustc_codegen_spirv/src/builder_spirv.rs index 6206f97c85..1d540aa75a 100644 --- a/rustc_codegen_spirv/src/builder_spirv.rs +++ b/rustc_codegen_spirv/src/builder_spirv.rs @@ -2,6 +2,7 @@ use bimap::BiHashMap; use rspirv::dr::{Block, Builder, Module, Operand}; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word}; use rspirv::{binary::Assemble, binary::Disassemble}; +use rustc_middle::bug; use std::cell::{RefCell, RefMut}; use std::{fs::File, io::Write, path::Path}; @@ -144,7 +145,7 @@ impl BuilderSpirv { .iter() .any(|inst| { inst.class.opcode == Op::Capability - && inst.operands[0] == Operand::Capability(capability) + && inst.operands[0].unwrap_capability() == capability }) } @@ -160,7 +161,7 @@ impl BuilderSpirv { } } - panic!("Function not found: {}", id); + bug!("Function not found: {}", id); } pub fn def_constant(&self, val: SpirvConst) -> SpirvValue { @@ -283,6 +284,6 @@ impl BuilderSpirv { } } - panic!("Block not found: {}", id); + bug!("Block not found: {}", id); } } diff --git a/rustc_codegen_spirv/src/codegen_cx/constant.rs b/rustc_codegen_spirv/src/codegen_cx/constant.rs index af0fb3c27e..365236bcd0 100644 --- a/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -5,6 +5,7 @@ use crate::spirv_type::SpirvType; use rspirv::spirv::Word; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{BaseTypeMethods, ConstMethods, MiscMethods, StaticMethods}; +use rustc_middle::bug; use rustc_middle::mir::interpret::{read_target_uint, Allocation, GlobalAlloc, Pointer}; use rustc_middle::ty::layout::TyAndLayout; use rustc_mir::interpret::Scalar; @@ -63,14 +64,20 @@ impl<'tcx> CodegenCx<'tcx> { SpirvType::Bool => match val { 0 => self.builder.def_constant(SpirvConst::Bool(ty, false)), 1 => self.builder.def_constant(SpirvConst::Bool(ty, true)), - _ => panic!("Invalid constant value for bool: {}", val), + _ => self + .tcx + .sess + .fatal(&format!("Invalid constant value for bool: {}", val)), }, SpirvType::Integer(128, _) => { let result = self.undef(ty); self.zombie_no_span(result.def, "u128 constant"); result } - other => panic!("constant_int invalid on type {}", other.debug(ty, self)), + other => self.tcx.sess.fatal(&format!( + "constant_int invalid on type {}", + other.debug(ty, self) + )), } } @@ -94,7 +101,10 @@ impl<'tcx> CodegenCx<'tcx> { SpirvType::Float(64) => self .builder .def_constant(SpirvConst::F64(ty, val.to_bits())), - other => panic!("constant_float invalid on type {}", other.debug(ty, self)), + other => self.tcx.sess.fatal(&format!( + "constant_float invalid on type {}", + other.debug(ty, self) + )), } } @@ -213,12 +223,15 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { SpirvType::Bool => match data { 0 => self.constant_bool(false), 1 => self.constant_bool(true), - _ => panic!("Invalid constant value for bool: {}", data), + _ => self + .tcx + .sess + .fatal(&format!("Invalid constant value for bool: {}", data)), }, - other => panic!( + other => self.tcx.sess.fatal(&format!( "scalar_to_backend Primitive::Int not supported on type {}", other.debug(ty, self) - ), + )), } } Primitive::F32 => { @@ -231,19 +244,17 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { assert_eq!(res.ty, ty); res } - Primitive::Pointer => { - panic!("scalar_to_backend Primitive::Ptr is an invalid state") - } + Primitive::Pointer => bug!("scalar_to_backend Primitive::Ptr is an invalid state"), }, Scalar::Ptr(ptr) => { let (base_addr, _base_addr_space) = match self.tcx.global_alloc(ptr.alloc_id) { GlobalAlloc::Memory(alloc) => { let pointee = match self.lookup_type(ty) { SpirvType::Pointer { pointee, .. } => pointee, - other => panic!( + other => self.tcx.sess.fatal(&format!( "GlobalAlloc::Memory type not implemented: {}", other.debug(ty, self) - ), + )), }; let init = self.create_const_alloc(alloc, pointee); let value = self.static_addr_of(init, alloc.align, None); @@ -262,12 +273,16 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { let value = if ptr.offset.bytes() == 0 { base_addr } else { - panic!("Non-constant scalar_to_backend ptr.offset not supported") + self.tcx + .sess + .fatal("Non-constant scalar_to_backend ptr.offset not supported") // let offset = self.constant_u64(ptr.offset.bytes()); // self.gep(base_addr, once(offset)) }; if layout.value != Primitive::Pointer { - panic!("Non-pointer-typed scalar_to_backend Scalar::Ptr not supported"); + self.tcx + .sess + .fatal("Non-pointer-typed scalar_to_backend Scalar::Ptr not supported"); // unsafe { llvm::LLVMConstPtrToInt(llval, llty) } } else { match (self.lookup_type(value.ty), self.lookup_type(ty)) { @@ -343,7 +358,10 @@ impl<'tcx> CodegenCx<'tcx> { // these print statements are really useful for debugging, so leave them easily available // println!("const at {}: {}", offset.bytes(), self.debug_type(ty)); match ty_concrete { - SpirvType::Void => panic!("Cannot create const alloc of type void"), + SpirvType::Void => self + .tcx + .sess + .fatal("Cannot create const alloc of type void"), SpirvType::Bool => self.constant_bool(self.read_alloc_val(alloc, offset, 1) != 0), SpirvType::Integer(width, _) => { let v = self.read_alloc_val(alloc, offset, (width / 8) as usize); @@ -354,7 +372,10 @@ impl<'tcx> CodegenCx<'tcx> { match width { 32 => self.constant_f32(f32::from_bits(v as u32)), 64 => self.constant_f64(f64::from_bits(v as u64)), - other => panic!("invalid float width {}", other), + other => self + .tcx + .sess + .fatal(&format!("invalid float width {}", other)), } } SpirvType::Adt { @@ -387,9 +408,10 @@ impl<'tcx> CodegenCx<'tcx> { Self::assert_uninit(alloc, base, *offset, occupied_spaces); self.constant_composite(ty, values) } - SpirvType::Opaque { name } => { - panic!("Cannot create const alloc of type opaque: {}", name) - } + SpirvType::Opaque { name } => self.tcx.sess.fatal(&format!( + "Cannot create const alloc of type opaque: {}", + name + )), SpirvType::Array { element, count } => { let count = self.builder.lookup_const_u64(count).unwrap() as usize; let values = (0..count) @@ -435,9 +457,10 @@ impl<'tcx> CodegenCx<'tcx> { ty, ) } - SpirvType::Function { .. } => { - panic!("TODO: SpirvType::Function not supported yet in create_const_alloc") - } + SpirvType::Function { .. } => self + .tcx + .sess + .fatal("TODO: SpirvType::Function not supported yet in create_const_alloc"), } } diff --git a/rustc_codegen_spirv/src/codegen_cx/declare.rs b/rustc_codegen_spirv/src/codegen_cx/declare.rs index 2fe7823c74..d02fb1308a 100644 --- a/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -9,6 +9,7 @@ use rspirv::spirv::{ }; use rustc_attr::InlineAttr; use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods}; +use rustc_middle::bug; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility}; use rustc_middle::ty::layout::FnAbiExt; @@ -78,7 +79,7 @@ impl<'tcx> CodegenCx<'tcx> { return_type, arguments, } => (return_type, arguments), - other => panic!("fn_abi type {}", other.debug(function_type, self)), + other => bug!("fn_abi type {}", other.debug(function_type, self)), }; if crate::is_blocklisted_fn(name) { @@ -169,10 +170,10 @@ impl<'tcx> CodegenCx<'tcx> { return_type, arguments, } => (return_type, arguments), - other => panic!( + other => self.tcx.sess.fatal(&format!( "Invalid entry_stub type: {}", other.debug(entry_func.ty, self) - ), + )), }; let mut emit = self.emit_global(); let mut decoration_locations = HashMap::new(); @@ -182,7 +183,10 @@ impl<'tcx> CodegenCx<'tcx> { .map(|&arg| { let storage_class = match self.lookup_type(arg) { SpirvType::Pointer { storage_class, .. } => storage_class, - other => panic!("Invalid entry arg type {}", other.debug(arg, self)), + other => self.tcx.sess.fatal(&format!( + "Invalid entry arg type {}", + other.debug(arg, self) + )), }; let has_location = match storage_class { StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant => { @@ -245,10 +249,10 @@ impl<'tcx> CodegenCx<'tcx> { return_type, arguments, } => (return_type, arguments), - other => panic!( + other => self.tcx.sess.fatal(&format!( "Invalid kernel_entry_stub type: {}", other.debug(entry_func.ty, self) - ), + )), }; let mut emit = self.emit_global(); let fn_id = emit @@ -292,7 +296,10 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { let linkage = match linkage { Linkage::External => Some(LinkageType::Export), Linkage::Internal => None, - other => panic!("TODO: Linkage type not supported yet: {:?}", other), + other => self.tcx.sess.fatal(&format!( + "TODO: Linkage type not supported yet: {:?}", + other + )), }; let span = self.tcx.def_span(def_id); @@ -321,7 +328,10 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { let linkage2 = match linkage { Linkage::External => Some(LinkageType::Export), Linkage::Internal => None, - other => panic!("TODO: Linkage type not supported yet: {:?}", other), + other => self.tcx.sess.fatal(&format!( + "TODO: Linkage type not supported yet: {:?}", + other + )), }; let rust_attrs = self.tcx.codegen_fn_attrs(instance.def_id()); let spv_attrs = attrs_to_spirv(rust_attrs); @@ -374,7 +384,10 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> { }; let value_ty = match self.lookup_type(g.ty) { SpirvType::Pointer { pointee, .. } => pointee, - other => panic!("global had non-pointer type {}", other.debug(g.ty, self)), + other => self.tcx.sess.fatal(&format!( + "global had non-pointer type {}", + other.debug(g.ty, self) + )), }; let mut v = self.create_const_alloc(alloc, value_ty); @@ -383,7 +396,7 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> { let val_int = match val { SpirvConst::Bool(_, false) => 0, SpirvConst::Bool(_, true) => 0, - _ => panic!(), + _ => bug!(), }; v = self.constant_u8(val_int); } diff --git a/rustc_codegen_spirv/src/codegen_cx/type_.rs b/rustc_codegen_spirv/src/codegen_cx/type_.rs index 6bb27c513a..25882973d2 100644 --- a/rustc_codegen_spirv/src/codegen_cx/type_.rs +++ b/rustc_codegen_spirv/src/codegen_cx/type_.rs @@ -5,6 +5,7 @@ use rspirv::spirv::StorageClass; use rspirv::spirv::Word; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeMethods, LayoutTypeMethods}; +use rustc_middle::bug; use rustc_middle::ty::layout::{LayoutError, TyAndLayout}; use rustc_middle::ty::{ParamEnv, Ty}; use rustc_span::source_map::{Span, DUMMY_SP}; @@ -26,7 +27,7 @@ impl<'tcx> LayoutOf for CodegenCx<'tcx> { if let LayoutError::SizeOverflow(_) = e { self.tcx.sess.span_fatal(span, &e.to_string()) } else { - panic!("failed to get layout for `{}`: {}", ty, e) + bug!("failed to get layout for `{}`: {}", ty, e) } }) } @@ -59,13 +60,13 @@ impl<'tcx> LayoutTypeMethods<'tcx> for CodegenCx<'tcx> { fn backend_field_index(&self, layout: TyAndLayout<'tcx>, index: usize) -> u64 { match layout.abi { Abi::Scalar(_) | Abi::ScalarPair(..) => { - panic!("backend_field_index({:?}): not applicable", layout) + bug!("backend_field_index({:?}): not applicable", layout) } _ => {} } match layout.fields { FieldsShape::Primitive | FieldsShape::Union(_) => { - panic!("backend_field_index({:?}): not applicable", layout) + bug!("backend_field_index({:?}): not applicable", layout) } FieldsShape::Array { .. } => index as u64, // note: codegen_llvm implements this as 1+index*2 due to padding fields @@ -161,7 +162,10 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { 16 => TypeKind::Half, 32 => TypeKind::Float, 64 => TypeKind::Double, - other => panic!("Invalid float width in type_kind: {}", other), + other => self + .tcx + .sess + .fatal(&format!("Invalid float width in type_kind: {}", other)), }, SpirvType::Adt { .. } => TypeKind::Struct, SpirvType::Opaque { .. } => TypeKind::Struct, @@ -193,7 +197,10 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { pointee, } => pointee, SpirvType::Vector { element, .. } => element, - spirv_type => panic!("element_type called on invalid type: {:?}", spirv_type), + spirv_type => self.tcx.sess.fatal(&format!( + "element_type called on invalid type: {:?}", + spirv_type + )), } } @@ -201,14 +208,20 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { fn vector_length(&self, ty: Self::Type) -> usize { match self.lookup_type(ty) { SpirvType::Vector { count, .. } => count as usize, - ty => panic!("vector_length called on non-vector type: {:?}", ty), + ty => self.tcx.sess.fatal(&format!( + "vector_length called on non-vector type: {:?}", + ty + )), } } fn float_width(&self, ty: Self::Type) -> usize { match self.lookup_type(ty) { SpirvType::Float(width) => width as usize, - ty => panic!("float_width called on non-float type: {:?}", ty), + ty => self + .tcx + .sess + .fatal(&format!("float_width called on non-float type: {:?}", ty)), } } @@ -216,7 +229,10 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { fn int_width(&self, ty: Self::Type) -> u64 { match self.lookup_type(ty) { SpirvType::Integer(width, _) => width as u64, - ty => panic!("int_width called on non-integer type: {:?}", ty), + ty => self + .tcx + .sess + .fatal(&format!("int_width called on non-integer type: {:?}", ty)), } } diff --git a/rustc_codegen_spirv/src/lib.rs b/rustc_codegen_spirv/src/lib.rs index 11b328fb34..f5754c5847 100644 --- a/rustc_codegen_spirv/src/lib.rs +++ b/rustc_codegen_spirv/src/lib.rs @@ -138,11 +138,11 @@ struct SpirvMetadataLoader; impl MetadataLoader for SpirvMetadataLoader { fn get_rlib_metadata(&self, _: &Target, path: &Path) -> Result { - Ok(link::read_metadata(path)) + link::read_metadata(path) } fn get_dylib_metadata(&self, _: &Target, _: &Path) -> Result { - panic!("TODO: implement get_dylib_metadata"); + Err("TODO: implement get_dylib_metadata".to_string()) } } @@ -526,5 +526,3 @@ pub fn __rustc_codegen_backend() -> Box { Box::new(SpirvCodegenBackend) } - -// https://github.com/bjorn3/rustc_codegen_cranelift/blob/1b8df386aa72bc3dacb803f7d4deb4eadd63b56f/src/base.rs diff --git a/rustc_codegen_spirv/src/link.rs b/rustc_codegen_spirv/src/link.rs index af5337023a..c9984247f3 100644 --- a/rustc_codegen_spirv/src/link.rs +++ b/rustc_codegen_spirv/src/link.rs @@ -6,6 +6,7 @@ use rustc_data_structures::owning_ref::OwningRef; use rustc_data_structures::rustc_erase_owner; use rustc_data_structures::sync::MetadataRef; use rustc_errors::{DiagnosticBuilder, FatalError}; +use rustc_middle::bug; use rustc_middle::dep_graph::WorkProduct; use rustc_middle::middle::cstore::NativeLib; use rustc_middle::middle::dependency_format::Linkage; @@ -40,9 +41,10 @@ pub fn link<'a>( } if invalid_output_for_target(sess, crate_type) { - panic!( + bug!( "invalid output type `{:?}` for target os `{}`", - crate_type, sess.opts.target_triple + crate_type, + sess.opts.target_triple ); } @@ -58,18 +60,18 @@ pub fn link<'a>( let out_filename = out_filename(sess, crate_type, outputs, crate_name); match crate_type { CrateType::Rlib => { - link_rlib(codegen_results, &out_filename); + link_rlib(sess, codegen_results, &out_filename); } CrateType::Executable | CrateType::Cdylib | CrateType::Dylib => { link_exe(sess, crate_type, &out_filename, codegen_results, legalize) } - other => panic!("CrateType {:?} not supported yet", other), + other => sess.err(&format!("CrateType {:?} not supported yet", other)), } } } } -fn link_rlib(codegen_results: &CodegenResults, out_filename: &Path) { +fn link_rlib(sess: &Session, codegen_results: &CodegenResults, out_filename: &Path) { let mut file_list = Vec::<&Path>::new(); for obj in codegen_results .modules @@ -88,7 +90,10 @@ fn link_rlib(codegen_results: &CodegenResults, out_filename: &Path) { | NativeLibKind::Unspecified => continue, } if let Some(name) = lib.name { - panic!("Adding native library to rlib not supported yet: {}", name); + sess.err(&format!( + "Adding native library to rlib not supported yet: {}", + name + )); } } @@ -222,7 +227,7 @@ fn link_local_crate_native_libs_and_dependent_crate_libs<'a>( if sess.opts.debugging_opts.link_native_libraries { add_local_native_libraries(sess, codegen_results); } - add_upstream_rust_crates(rlibs, codegen_results, crate_type); + add_upstream_rust_crates(sess, rlibs, codegen_results, crate_type); if sess.opts.debugging_opts.link_native_libraries { add_upstream_native_libraries(sess, codegen_results, crate_type); } @@ -238,6 +243,7 @@ fn add_local_native_libraries(sess: &Session, codegen_results: &CodegenResults) } fn add_upstream_rust_crates( + sess: &Session, rlibs: &mut Vec, codegen_results: &CodegenResults, crate_type: CrateType, @@ -255,7 +261,7 @@ fn add_upstream_rust_crates( Linkage::NotLinked | Linkage::IncludedFromDylib => {} Linkage::Static => rlibs.push(src.rlib.as_ref().unwrap().0.clone()), //Linkage::Dynamic => rlibs.push(src.dylib.as_ref().unwrap().0.clone()), - Linkage::Dynamic => panic!("TODO: Linkage::Dynamic not supported yet"), + Linkage::Dynamic => sess.err("TODO: Linkage::Dynamic not supported yet"), } } } @@ -283,23 +289,25 @@ fn add_upstream_native_libraries( continue; } match lib.kind { - NativeLibKind::Dylib | NativeLibKind::Unspecified => { - panic!("TODO: dylib nativelibkind not supported yet: {}", name) - } - NativeLibKind::Framework => { - panic!("TODO: framework nativelibkind not supported yet: {}", name) - } + NativeLibKind::Dylib | NativeLibKind::Unspecified => sess.fatal(&format!( + "TODO: dylib nativelibkind not supported yet: {}", + name + )), + NativeLibKind::Framework => sess.fatal(&format!( + "TODO: framework nativelibkind not supported yet: {}", + name + )), NativeLibKind::StaticNoBundle => { if data[cnum.as_usize() - 1] == Linkage::Static { - panic!( + sess.fatal(&format!( "TODO: staticnobundle nativelibkind not supported yet: {}", name - ); + )) } } NativeLibKind::StaticBundle => {} NativeLibKind::RawDylib => { - panic!("raw_dylib feature not yet implemented: {}", name); + sess.fatal(&format!("raw_dylib feature not yet implemented: {}", name)) } } } @@ -338,17 +346,17 @@ fn create_archive(files: &[&Path], metadata: &[u8], out_filename: &Path) { builder.into_inner().unwrap(); } -pub fn read_metadata(rlib: &Path) -> MetadataRef { +pub fn read_metadata(rlib: &Path) -> Result { for entry in Archive::new(File::open(rlib).unwrap()).entries().unwrap() { let mut entry = entry.unwrap(); if entry.path().unwrap() == Path::new(".metadata") { let mut bytes = Vec::new(); entry.read_to_end(&mut bytes).unwrap(); let buf: OwningRef, [u8]> = OwningRef::new(bytes); - return rustc_erase_owner!(buf.map_owner_box()); + return Ok(rustc_erase_owner!(buf.map_owner_box())); } } - panic!("No .metadata file in rlib: {:?}", rlib); + Err(format!("No .metadata file in rlib: {:?}", rlib)) } /// This is the actual guts of linking: the rest of the link-related functions are just digging through rustc's @@ -424,7 +432,7 @@ fn do_link( .unwrap(); } } - panic!("Linker error: {}", err) + sess.fatal(&format!("Linker error: {}", err)) } }; @@ -455,7 +463,7 @@ pub(crate) fn run_thin( } if cgcx.lto != Lto::ThinLocal { for _ in cgcx.each_linked_rlib_for_lto.iter() { - panic!("TODO: Implement whatever the heck this is"); + bug!("TODO: Implement whatever the heck this is"); } } let mut thin_buffers = Vec::with_capacity(modules.len()); diff --git a/rustc_codegen_spirv/src/spirv_type.rs b/rustc_codegen_spirv/src/spirv_type.rs index 0e80cbc858..2f08499862 100644 --- a/rustc_codegen_spirv/src/spirv_type.rs +++ b/rustc_codegen_spirv/src/spirv_type.rs @@ -82,7 +82,10 @@ impl SpirvType { } 8 | 16 | 32 | 64 => (), 128 => cx.zombie_no_span(result, "u128"), - other => panic!("Integer width {} invalid for spir-v", other), + other => cx + .tcx + .sess + .fatal(&format!("Integer width {} invalid for spir-v", other)), }; result } @@ -93,7 +96,10 @@ impl SpirvType { cx.zombie_no_span(result, "f64 without OpCapability Float64") } 32 | 64 => (), - other => panic!("Float width {} invalid for spir-v", other), + other => cx + .tcx + .sess + .fatal(&format!("Float width {} invalid for spir-v", other)), }; result } @@ -201,7 +207,10 @@ impl SpirvType { } result } - ref other => panic!("def_with_id invalid for type {:?}", other), + ref other => cx + .tcx + .sess + .fatal(&format!("def_with_id invalid for type {:?}", other)), }; cx.type_cache.def(result, self); result