mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2025-02-16 17:04:16 +00:00
asm: streamline generating a type from a pattern for inference.
This commit is contained in:
parent
6c8dece25a
commit
fedac0f4a5
@ -15,7 +15,7 @@ use rustc_codegen_ssa::traits::{AsmBuilderMethods, InlineAsmOperandRef};
|
|||||||
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
|
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
|
||||||
use rustc_hir::LlvmInlineAsmInner;
|
use rustc_hir::LlvmInlineAsmInner;
|
||||||
use rustc_middle::bug;
|
use rustc_middle::bug;
|
||||||
use rustc_span::source_map::Span;
|
use rustc_span::{Span, DUMMY_SP};
|
||||||
use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass};
|
use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass};
|
||||||
|
|
||||||
pub struct InstructionTable {
|
pub struct InstructionTable {
|
||||||
@ -546,24 +546,24 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
|||||||
struct Unapplicable;
|
struct Unapplicable;
|
||||||
|
|
||||||
/// Recursively match `ty` against `pat`, returning one of:
|
/// Recursively match `ty` against `pat`, returning one of:
|
||||||
/// * `Ok(None)`: `pat` matched but contained no type variables
|
/// * `Ok([None])`: `pat` matched but contained no type variables
|
||||||
/// * `Ok(Some(var))`: `pat` matched and `var` is the type variable
|
/// * `Ok([Some(var)])`: `pat` matched and `var` is the type variable
|
||||||
/// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now
|
/// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now
|
||||||
fn match_ty_pat(
|
fn match_ty_pat(
|
||||||
cx: &CodegenCx<'_>,
|
cx: &CodegenCx<'_>,
|
||||||
pat: &TyPat<'_>,
|
pat: &TyPat<'_>,
|
||||||
ty: Word,
|
ty: Word,
|
||||||
) -> Result<Option<Word>, Unapplicable> {
|
) -> Result<[Option<Word>; 1], Unapplicable> {
|
||||||
match pat {
|
match pat {
|
||||||
TyPat::Any => Ok(None),
|
TyPat::Any => Ok([None]),
|
||||||
&TyPat::T => Ok(Some(ty)),
|
&TyPat::T => Ok([Some(ty)]),
|
||||||
TyPat::Either(a, b) => {
|
TyPat::Either(a, b) => {
|
||||||
match_ty_pat(cx, a, ty).or_else(|Unapplicable| match_ty_pat(cx, b, ty))
|
match_ty_pat(cx, a, ty).or_else(|Unapplicable| match_ty_pat(cx, b, ty))
|
||||||
}
|
}
|
||||||
_ => match (pat, cx.lookup_type(ty)) {
|
_ => match (pat, cx.lookup_type(ty)) {
|
||||||
(TyPat::Any, _) | (&TyPat::T, _) | (TyPat::Either(..), _) => unreachable!(),
|
(TyPat::Any, _) | (&TyPat::T, _) | (TyPat::Either(..), _) => unreachable!(),
|
||||||
|
|
||||||
(TyPat::Void, SpirvType::Void) => Ok(None),
|
(TyPat::Void, SpirvType::Void) => Ok([None]),
|
||||||
(TyPat::Pointer(_, pat), SpirvType::Pointer { pointee: ty, .. })
|
(TyPat::Pointer(_, pat), SpirvType::Pointer { pointee: ty, .. })
|
||||||
| (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. })
|
| (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. })
|
||||||
| (
|
| (
|
||||||
@ -587,6 +587,38 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Ambiguous;
|
||||||
|
|
||||||
|
/// Construct a type from `pat`, replacing `TyPat::Var(i)` with `ty_vars[i]`.
|
||||||
|
/// If the pattern isn't constraining enough to determine an unique type,
|
||||||
|
/// `Err(Ambiguous)` is returned instead.
|
||||||
|
fn subst_ty_pat(
|
||||||
|
cx: &CodegenCx<'_>,
|
||||||
|
pat: &TyPat<'_>,
|
||||||
|
ty_vars: &[Option<Word>],
|
||||||
|
) -> Result<Word, Ambiguous> {
|
||||||
|
Ok(match pat {
|
||||||
|
&TyPat::Var(i) => match ty_vars.get(i) {
|
||||||
|
Some(&Some(ty)) => ty,
|
||||||
|
_ => return Err(Ambiguous),
|
||||||
|
},
|
||||||
|
|
||||||
|
TyPat::Vector4(pat) => SpirvType::Vector {
|
||||||
|
element: subst_ty_pat(cx, pat, ty_vars)?,
|
||||||
|
count: 4,
|
||||||
|
}
|
||||||
|
.def(DUMMY_SP, cx),
|
||||||
|
|
||||||
|
TyPat::SampledImage(pat) => SpirvType::SampledImage {
|
||||||
|
image_type: subst_ty_pat(cx, pat, ty_vars)?,
|
||||||
|
}
|
||||||
|
.def(DUMMY_SP, cx),
|
||||||
|
|
||||||
|
_ => return Err(Ambiguous),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// FIXME(eddyb) try multiple signatures until one fits.
|
// FIXME(eddyb) try multiple signatures until one fits.
|
||||||
let mut sig = match instruction_signatures(instruction.class.opcode)? {
|
let mut sig = match instruction_signatures(instruction.class.opcode)? {
|
||||||
[sig
|
[sig
|
||||||
@ -597,7 +629,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
|||||||
_ => return None,
|
_ => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut combined_var = None;
|
let mut combined_ty_vars = [None];
|
||||||
|
|
||||||
let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
|
let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
|
||||||
while let TyListPat::Cons { first: pat, suffix } = *sig.input_types {
|
while let TyListPat::Cons { first: pat, suffix } = *sig.input_types {
|
||||||
@ -609,25 +641,31 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
|||||||
// Non-value ID operand (or value operand of unknown type),
|
// Non-value ID operand (or value operand of unknown type),
|
||||||
// only `TyPat::Any` is valid.
|
// only `TyPat::Any` is valid.
|
||||||
None => match pat {
|
None => match pat {
|
||||||
TyPat::Any => Ok(None),
|
TyPat::Any => Ok([None]),
|
||||||
_ => Err(Unapplicable),
|
_ => Err(Unapplicable),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
match match_result {
|
|
||||||
Ok(Some(var)) => match combined_var {
|
let ty_vars = match match_result {
|
||||||
Some(combined_var) => {
|
Ok(ty_vars) => ty_vars,
|
||||||
// FIXME(eddyb) this could use some error reporting
|
|
||||||
// (it's a type mismatch), although we could also
|
|
||||||
// just use the first type and let validation take
|
|
||||||
// care of the mismatch
|
|
||||||
if var != combined_var {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => combined_var = Some(var),
|
|
||||||
},
|
|
||||||
Ok(None) => {}
|
|
||||||
Err(Unapplicable) => return None,
|
Err(Unapplicable) => return None,
|
||||||
|
};
|
||||||
|
|
||||||
|
for (&var, combined_var) in ty_vars.iter().zip(&mut combined_ty_vars) {
|
||||||
|
if let Some(var) = var {
|
||||||
|
match *combined_var {
|
||||||
|
Some(combined_var) => {
|
||||||
|
// FIXME(eddyb) this could use some error reporting
|
||||||
|
// (it's a type mismatch), although we could also
|
||||||
|
// just use the first type and let validation take
|
||||||
|
// care of the mismatch
|
||||||
|
if var != combined_var {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => *combined_var = Some(var),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match sig.input_types {
|
match sig.input_types {
|
||||||
@ -642,20 +680,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
|||||||
_ => return None,
|
_ => return None,
|
||||||
}
|
}
|
||||||
|
|
||||||
let var = combined_var?;
|
match subst_ty_pat(self, sig.output_type.unwrap(), &combined_ty_vars) {
|
||||||
match sig.output_type.unwrap() {
|
Ok(ty) => Some(ty),
|
||||||
&TyPat::T => Some(var),
|
Err(Ambiguous) => None,
|
||||||
TyPat::Vector4(&TyPat::T) => Some(
|
|
||||||
SpirvType::Vector {
|
|
||||||
element: var,
|
|
||||||
count: 4,
|
|
||||||
}
|
|
||||||
.def(self.span(), self),
|
|
||||||
),
|
|
||||||
TyPat::SampledImage(&TyPat::T) => {
|
|
||||||
Some(SpirvType::SampledImage { image_type: var }.def(self.span(), self))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user