diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 1238156613..3b26f4a318 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -15,7 +15,7 @@ use rustc_codegen_ssa::traits::{AsmBuilderMethods, InlineAsmOperandRef}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_hir::LlvmInlineAsmInner; use rustc_middle::bug; -use rustc_span::source_map::Span; +use rustc_span::{Span, DUMMY_SP}; use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass}; pub struct InstructionTable { @@ -546,24 +546,24 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { struct Unapplicable; /// Recursively match `ty` against `pat`, returning one of: - /// * `Ok(None)`: `pat` matched but contained no type variables - /// * `Ok(Some(var))`: `pat` matched and `var` is the type variable + /// * `Ok([None])`: `pat` matched but contained no type variables + /// * `Ok([Some(var)])`: `pat` matched and `var` is the type variable /// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now fn match_ty_pat( cx: &CodegenCx<'_>, pat: &TyPat<'_>, ty: Word, - ) -> Result, Unapplicable> { + ) -> Result<[Option; 1], Unapplicable> { match pat { - TyPat::Any => Ok(None), - &TyPat::T => Ok(Some(ty)), + TyPat::Any => Ok([None]), + &TyPat::T => Ok([Some(ty)]), TyPat::Either(a, b) => { match_ty_pat(cx, a, ty).or_else(|Unapplicable| match_ty_pat(cx, b, ty)) } _ => match (pat, cx.lookup_type(ty)) { (TyPat::Any, _) | (&TyPat::T, _) | (TyPat::Either(..), _) => unreachable!(), - (TyPat::Void, SpirvType::Void) => Ok(None), + (TyPat::Void, SpirvType::Void) => Ok([None]), (TyPat::Pointer(_, pat), SpirvType::Pointer { pointee: ty, .. }) | (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. }) | ( @@ -587,6 +587,38 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } } + #[derive(Debug)] + struct Ambiguous; + + /// Construct a type from `pat`, replacing `TyPat::Var(i)` with `ty_vars[i]`. + /// If the pattern isn't constraining enough to determine an unique type, + /// `Err(Ambiguous)` is returned instead. + fn subst_ty_pat( + cx: &CodegenCx<'_>, + pat: &TyPat<'_>, + ty_vars: &[Option], + ) -> Result { + Ok(match pat { + &TyPat::Var(i) => match ty_vars.get(i) { + Some(&Some(ty)) => ty, + _ => return Err(Ambiguous), + }, + + TyPat::Vector4(pat) => SpirvType::Vector { + element: subst_ty_pat(cx, pat, ty_vars)?, + count: 4, + } + .def(DUMMY_SP, cx), + + TyPat::SampledImage(pat) => SpirvType::SampledImage { + image_type: subst_ty_pat(cx, pat, ty_vars)?, + } + .def(DUMMY_SP, cx), + + _ => return Err(Ambiguous), + }) + } + // FIXME(eddyb) try multiple signatures until one fits. let mut sig = match instruction_signatures(instruction.class.opcode)? { [sig @@ -597,7 +629,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { _ => return None, }; - let mut combined_var = None; + let mut combined_ty_vars = [None]; let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any()); while let TyListPat::Cons { first: pat, suffix } = *sig.input_types { @@ -609,25 +641,31 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { // Non-value ID operand (or value operand of unknown type), // only `TyPat::Any` is valid. None => match pat { - TyPat::Any => Ok(None), + TyPat::Any => Ok([None]), _ => Err(Unapplicable), }, }; - match match_result { - Ok(Some(var)) => match combined_var { - Some(combined_var) => { - // FIXME(eddyb) this could use some error reporting - // (it's a type mismatch), although we could also - // just use the first type and let validation take - // care of the mismatch - if var != combined_var { - return None; - } - } - None => combined_var = Some(var), - }, - Ok(None) => {} + + let ty_vars = match match_result { + Ok(ty_vars) => ty_vars, Err(Unapplicable) => return None, + }; + + for (&var, combined_var) in ty_vars.iter().zip(&mut combined_ty_vars) { + if let Some(var) = var { + match *combined_var { + Some(combined_var) => { + // FIXME(eddyb) this could use some error reporting + // (it's a type mismatch), although we could also + // just use the first type and let validation take + // care of the mismatch + if var != combined_var { + return None; + } + } + None => *combined_var = Some(var), + } + } } } match sig.input_types { @@ -642,20 +680,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { _ => return None, } - let var = combined_var?; - match sig.output_type.unwrap() { - &TyPat::T => Some(var), - TyPat::Vector4(&TyPat::T) => Some( - SpirvType::Vector { - element: var, - count: 4, - } - .def(self.span(), self), - ), - TyPat::SampledImage(&TyPat::T) => { - Some(SpirvType::SampledImage { image_type: var }.def(self.span(), self)) - } - _ => None, + match subst_ty_pat(self, sig.output_type.unwrap(), &combined_ty_vars) { + Ok(ty) => Some(ty), + Err(Ambiguous) => None, } }