Support record pattern MIR lowering

This commit is contained in:
hkalbasi 2023-03-14 17:02:38 +03:30
parent 513e340bd3
commit 051dae2221
5 changed files with 201 additions and 63 deletions

View File

@ -555,6 +555,38 @@ fn structs() {
"#,
17,
);
check_number(
r#"
struct Point {
x: i32,
y: i32,
}
const GOAL: i32 = {
let p = Point { x: 5, y: 2 };
let p2 = Point { x: 3, ..p };
p.x * 1000 + p.y * 100 + p2.x * 10 + p2.y
};
"#,
5232,
);
check_number(
r#"
struct Point {
x: i32,
y: i32,
}
const GOAL: i32 = {
let p = Point { x: 5, y: 2 };
let Point { x, y } = p;
let Point { x: x2, .. } = p;
let Point { y: y2, .. } = p;
x * 1000 + y * 100 + x2 * 10 + y2
};
"#,
5252,
);
}
#[test]
@ -599,13 +631,14 @@ fn tuples() {
);
check_number(
r#"
struct TupleLike(i32, u8, i64, u16);
const GOAL: u8 = {
struct TupleLike(i32, i64, u8, u16);
const GOAL: i64 = {
let a = TupleLike(10, 20, 3, 15);
a.1
let TupleLike(b, .., c) = a;
a.1 * 100 + b as i64 + c as i64
};
"#,
20,
2025,
);
check_number(
r#"

View File

@ -711,12 +711,13 @@ pub fn is_dyn_method(
};
let self_ty = trait_ref.self_type_parameter(Interner);
if let TyKind::Dyn(d) = self_ty.kind(Interner) {
let is_my_trait_in_bounds = d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() {
// rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter
// what the generics are, we are sure that the method is come from the vtable.
WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id,
_ => false,
});
let is_my_trait_in_bounds =
d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() {
// rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter
// what the generics are, we are sure that the method is come from the vtable.
WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id,
_ => false,
});
if is_my_trait_in_bounds {
return Some(fn_params);
}

View File

@ -25,8 +25,8 @@ use crate::{
mapping::from_chalk,
method_resolution::{is_dyn_method, lookup_impl_method},
traits::FnTrait,
CallableDefId, Const, ConstScalar, FnDefId, Interner, MemoryMap, Substitution,
TraitEnvironment, Ty, TyBuilder, TyExt, GenericArgData,
CallableDefId, Const, ConstScalar, FnDefId, GenericArgData, Interner, MemoryMap, Substitution,
TraitEnvironment, Ty, TyBuilder, TyExt,
};
use super::{
@ -1315,10 +1315,13 @@ impl Evaluator<'_> {
args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec();
let generics_for_target = Substitution::from_iter(
Interner,
generic_args
.iter(Interner)
.enumerate()
.map(|(i, x)| if i == self_ty_idx { &ty } else { x })
generic_args.iter(Interner).enumerate().map(|(i, x)| {
if i == self_ty_idx {
&ty
} else {
x
}
}),
);
return self.exec_fn_with_args(
def,

View File

@ -4,16 +4,17 @@ use std::{iter, mem, sync::Arc};
use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind};
use hir_def::{
adt::VariantData,
body::Body,
expr::{
Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId,
RecordLitField,
RecordFieldPat, RecordLitField,
},
lang_item::{LangItem, LangItemTarget},
layout::LayoutError,
path::Path,
resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, TraitId,
DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, LocalFieldId, TraitId,
};
use hir_expand::name::Name;
use la_arena::ArenaMap;
@ -106,6 +107,12 @@ impl MirLowerError {
type Result<T> = std::result::Result<T, MirLowerError>;
enum AdtPatternShape<'a> {
Tuple { args: &'a [PatId], ellipsis: Option<usize> },
Record { args: &'a [RecordFieldPat] },
Unit,
}
impl MirLowerCtx<'_> {
fn temp(&mut self, ty: Ty) -> Result<LocalId> {
if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) {
@ -444,7 +451,8 @@ impl MirLowerCtx<'_> {
current,
pat.into(),
Some(end),
&[pat], &None)?;
AdtPatternShape::Tuple { args: &[pat], ellipsis: None },
)?;
if let Some((_, block)) = this.lower_expr_as_place(current, body, true)? {
this.set_goto(block, begin);
}
@ -573,7 +581,17 @@ impl MirLowerCtx<'_> {
Ok(None)
}
Expr::Yield { .. } => not_supported!("yield"),
Expr::RecordLit { fields, path, .. } => {
Expr::RecordLit { fields, path, spread, ellipsis: _, is_assignee_expr: _ } => {
let spread_place = match spread {
&Some(x) => {
let Some((p, c)) = self.lower_expr_as_place(current, x, true)? else {
return Ok(None);
};
current = c;
Some(p)
},
None => None,
};
let variant_id = self
.infer
.variant_resolution_for_expr(expr_id)
@ -603,9 +621,24 @@ impl MirLowerCtx<'_> {
place,
Rvalue::Aggregate(
AggregateKind::Adt(variant_id, subst),
operands.into_iter().map(|x| x).collect::<Option<_>>().ok_or(
MirLowerError::TypeError("missing field in record literal"),
)?,
match spread_place {
Some(sp) => operands.into_iter().enumerate().map(|(i, x)| {
match x {
Some(x) => x,
None => {
let mut p = sp.clone();
p.projection.push(ProjectionElem::Field(FieldId {
parent: variant_id,
local_id: LocalFieldId::from_raw(RawIdx::from(i as u32)),
}));
Operand::Copy(p)
},
}
}).collect(),
None => operands.into_iter().map(|x| x).collect::<Option<_>>().ok_or(
MirLowerError::TypeError("missing field in record literal"),
)?,
},
),
expr_id.into(),
);
@ -1021,14 +1054,11 @@ impl MirLowerCtx<'_> {
self.pattern_match_tuple_like(
current,
current_else,
args.iter().enumerate().map(|(i, x)| {
(
PlaceElem::TupleField(i),
*x,
subst.at(Interner, i).assert_ty_ref(Interner).clone(),
)
}),
args,
*ellipsis,
subst.iter(Interner).enumerate().map(|(i, x)| {
(PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone())
}),
&cond_place,
binding_mode,
)?
@ -1062,7 +1092,21 @@ impl MirLowerCtx<'_> {
}
(then_target, current_else)
}
Pat::Record { .. } => not_supported!("record pattern"),
Pat::Record { args, .. } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Record { args: &*args },
)?
}
Pat::Range { .. } => not_supported!("range pattern"),
Pat::Slice { .. } => not_supported!("slice pattern"),
Pat::Path(_) => {
@ -1077,8 +1121,7 @@ impl MirLowerCtx<'_> {
current,
pattern.into(),
current_else,
&[],
&None,
AdtPatternShape::Unit,
)?
}
Pat::Lit(l) => {
@ -1160,8 +1203,7 @@ impl MirLowerCtx<'_> {
current,
pattern.into(),
current_else,
args,
ellipsis,
AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
)?
}
Pat::Ref { .. } => not_supported!("& pattern"),
@ -1179,15 +1221,13 @@ impl MirLowerCtx<'_> {
current: BasicBlockId,
span: MirSpan,
current_else: Option<BasicBlockId>,
args: &[PatId],
ellipsis: &Option<usize>,
shape: AdtPatternShape<'_>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
};
let fields_type = self.db.field_types(variant);
Ok(match variant {
VariantId::EnumVariantId(v) => {
let e = self.db.const_eval_discriminant(v)? as u128;
@ -1208,35 +1248,26 @@ impl MirLowerCtx<'_> {
},
);
let enum_data = self.db.enum_data(v.parent);
let fields =
enum_data.variants[v.local_id].variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
});
self.pattern_match_tuple_like(
self.pattern_matching_variant_fields(
shape,
&enum_data.variants[v.local_id].variant_data,
variant,
subst,
next,
Some(else_target),
args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)),
*ellipsis,
&cond_place,
binding_mode,
)?
}
VariantId::StructId(s) => {
let struct_data = self.db.struct_data(s);
let fields = struct_data.variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: s.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
});
self.pattern_match_tuple_like(
self.pattern_matching_variant_fields(
shape,
&struct_data.variant_data,
variant,
subst,
current,
current_else,
args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)),
*ellipsis,
&cond_place,
binding_mode,
)?
@ -1247,18 +1278,69 @@ impl MirLowerCtx<'_> {
})
}
fn pattern_match_tuple_like(
fn pattern_matching_variant_fields(
&mut self,
shape: AdtPatternShape<'_>,
variant_data: &VariantData,
v: VariantId,
subst: &Substitution,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let fields_type = self.db.field_types(v);
Ok(match shape {
AdtPatternShape::Record { args } => {
let it = args
.iter()
.map(|x| {
let field_id =
variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?;
Ok((
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
x.pat,
fields_type[field_id].clone().substitute(Interner, subst),
))
})
.collect::<Result<Vec<_>>>()?;
self.pattern_match_adt(
current,
current_else,
it.into_iter(),
cond_place,
binding_mode,
)?
}
AdtPatternShape::Tuple { args, ellipsis } => {
let fields = variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
});
self.pattern_match_tuple_like(
current,
current_else,
args,
ellipsis,
fields,
cond_place,
binding_mode,
)?
}
AdtPatternShape::Unit => (current, current_else),
})
}
fn pattern_match_adt(
&mut self,
mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>,
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>,
ellipsis: Option<usize>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
if ellipsis.is_some() {
not_supported!("tuple like pattern with ellipsis");
}
for (proj, arg, ty) in args {
let mut cond_place = cond_place.clone();
cond_place.projection.push(proj);
@ -1268,6 +1350,25 @@ impl MirLowerCtx<'_> {
Ok((current, current_else))
}
fn pattern_match_tuple_like(
&mut self,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
args: &[PatId],
ellipsis: Option<usize>,
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
let it = al
.iter()
.zip(fields.clone())
.chain(ar.iter().rev().zip(fields.rev()))
.map(|(x, y)| (y.0, *x, y.1));
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode)
}
fn discr_temp_place(&mut self) -> Place {
match &self.discr_temp {
Some(x) => x.clone(),

View File

@ -295,7 +295,7 @@ impl<T> Arena<T> {
/// ```
pub fn iter(
&self,
) -> impl Iterator<Item = (Idx<T>, &T)> + ExactSizeIterator + DoubleEndedIterator {
) -> impl Iterator<Item = (Idx<T>, &T)> + ExactSizeIterator + DoubleEndedIterator + Clone {
self.data.iter().enumerate().map(|(idx, value)| (Idx::from_raw(RawIdx(idx as u32)), value))
}