Fix deduction of dyn Fn closure parameter types

This commit is contained in:
Jonas Schievink 2021-06-29 17:35:37 +02:00
parent 13cbe64a59
commit 88a86d4ff9
6 changed files with 139 additions and 21 deletions

View File

@ -2,10 +2,7 @@
//! HIR back into source code, and just displaying them for debugging/testing
//! purposes.
use std::{
array,
fmt::{self, Debug},
};
use std::fmt::{self, Debug};
use chalk_ir::BoundVar;
use hir_def::{
@ -23,12 +20,16 @@ use hir_def::{
use hir_expand::{hygiene::Hygiene, name::Name};
use crate::{
const_from_placeholder_idx, db::HirDatabase, from_assoc_type_id, from_foreign_def_id,
from_placeholder_idx, lt_from_placeholder_idx, mapping::from_chalk, primitive, subst_prefix,
to_assoc_type_id, utils::generics, AdtId, AliasEq, AliasTy, CallableDefId, CallableSig, Const,
ConstValue, DomainGoal, GenericArg, ImplTraitId, Interner, Lifetime, LifetimeData,
LifetimeOutlives, Mutability, OpaqueTy, ProjectionTy, ProjectionTyExt, QuantifiedWhereClause,
Scalar, TraitRef, TraitRefExt, Ty, TyExt, TyKind, WhereClause,
const_from_placeholder_idx,
db::HirDatabase,
from_assoc_type_id, from_foreign_def_id, from_placeholder_idx, lt_from_placeholder_idx,
mapping::from_chalk,
primitive, subst_prefix, to_assoc_type_id,
utils::{self, generics},
AdtId, AliasEq, AliasTy, CallableDefId, CallableSig, Const, ConstValue, DomainGoal, GenericArg,
ImplTraitId, Interner, Lifetime, LifetimeData, LifetimeOutlives, Mutability, OpaqueTy,
ProjectionTy, ProjectionTyExt, QuantifiedWhereClause, Scalar, TraitRef, TraitRefExt, Ty, TyExt,
TyKind, WhereClause,
};
pub struct HirFormatter<'a> {
@ -706,12 +707,7 @@ impl HirDisplay for CallableSig {
fn fn_traits(db: &dyn DefDatabase, trait_: TraitId) -> impl Iterator<Item = TraitId> {
let krate = trait_.lookup(db).container.krate();
let fn_traits = [
db.lang_item(krate, "fn".into()),
db.lang_item(krate, "fn_mut".into()),
db.lang_item(krate, "fn_once".into()),
];
array::IntoIter::new(fn_traits).into_iter().flatten().flat_map(|it| it.as_trait())
utils::fn_traits(db, krate)
}
pub fn write_bounds_like_dyn_trait_with_prefix(

View File

@ -52,6 +52,7 @@ mod path;
mod expr;
mod pat;
mod coerce;
mod closure;
/// The entry point of type inference.
pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {

View File

@ -0,0 +1,93 @@
//! Inference of closure parameter types based on the closure's expected type.
use chalk_ir::{fold::Shift, AliasTy, FnSubst, WhereClause};
use hir_def::HasModule;
use smallvec::SmallVec;
use crate::{
to_chalk_trait_id, utils, ChalkTraitId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty,
TyKind,
};
use super::{Expectation, InferenceContext};
impl InferenceContext<'_> {
pub(super) fn deduce_closure_type_from_expectations(
&mut self,
closure_ty: &Ty,
sig_ty: &Ty,
expectation: &Expectation,
) {
let expected_ty = match expectation.to_option(&mut self.table) {
Some(ty) => ty,
None => return,
};
// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
self.coerce(closure_ty, &expected_ty);
// Deduction based on the expected `dyn Fn` is done separately.
if let TyKind::Dyn(dyn_ty) = expected_ty.kind(&Interner) {
if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {
let expected_sig_ty = TyKind::Function(sig).intern(&Interner);
self.unify(sig_ty, &expected_sig_ty);
}
}
}
fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
// Search for predicates like `$self: FnX<Args>` and `<$self as FnOnce<...>>::Output == Ret`
let fn_traits: SmallVec<[ChalkTraitId; 3]> =
utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate())
.map(|tid| to_chalk_trait_id(tid))
.collect();
for bound in dyn_ty.bounds.map_ref(|b| b.iter(&Interner)) {
let bound = bound.map(|b| b.clone()).fuse_binders(&Interner);
match bound.skip_binders() {
WhereClause::AliasEq(eq) => match &eq.alias {
AliasTy::Projection(projection) => {
let assoc_data = self.db.associated_ty_data(projection.associated_ty_id);
if !fn_traits.contains(&assoc_data.trait_id) {
return None;
}
// Skip `Self`, get the type argument.
let arg = projection.substitution.as_slice(&Interner).get(1)?;
match arg.ty(&Interner)?.kind(&Interner) {
TyKind::Tuple(_, subst) => {
let generic_args = subst.as_slice(&Interner);
let mut sig_tys = Vec::new();
for arg in generic_args {
sig_tys.push(arg.ty(&Interner)?.clone());
}
sig_tys.push(eq.ty.clone());
cov_mark::hit!(dyn_fn_param_informs_call_site_closure_signature);
return Some(FnPointer {
num_binders: 0,
sig: FnSig {
abi: (),
safety: chalk_ir::Safety::Safe,
variadic: false,
},
substitution: FnSubst(
Substitution::from_iter(&Interner, sig_tys.clone())
.shifted_in(&Interner),
),
});
}
_ => {}
}
}
AliasTy::Opaque(_) => {}
},
_ => {}
}
}
None
}
}

View File

@ -278,15 +278,13 @@ impl<'a> InferenceContext<'a> {
.intern(&Interner);
let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
let closure_ty =
TyKind::Closure(closure_id, Substitution::from1(&Interner, sig_ty))
TyKind::Closure(closure_id, Substitution::from1(&Interner, sig_ty.clone()))
.intern(&Interner);
// Eagerly try to relate the closure type with the expected
// type, otherwise we often won't have enough information to
// infer the body.
if let Some(t) = expected.only_has_type(&mut self.table) {
self.coerce(&closure_ty, &t);
}
self.deduce_closure_type_from_expectations(&closure_ty, &sig_ty, expected);
// Now go through the argument patterns
for (arg_pat, arg_ty) in args.iter().zip(sig_tys) {

View File

@ -2829,6 +2829,26 @@ fn foo() {
);
}
#[test]
fn dyn_fn_param_informs_call_site_closure_signature() {
cov_mark::check!(dyn_fn_param_informs_call_site_closure_signature);
check_types(
r#"
//- minicore: fn, coerce_unsized
struct S;
impl S {
fn inherent(&self) -> u8 { 0 }
}
fn take_dyn_fn(f: &dyn Fn(S)) {}
fn f() {
take_dyn_fn(&|x| { x.inherent(); });
//^^^^^^^^^^^^ u8
}
"#,
);
}
#[test]
fn infer_fn_trait_arg() {
check_infer_with_mismatches(

View File

@ -1,8 +1,9 @@
//! Helper functions for working with def, which don't need to be a separate
//! query, but can't be computed directly from `*Data` (ie, which need a `db`).
use std::iter;
use std::{array, iter};
use base_db::CrateId;
use chalk_ir::{fold::Shift, BoundVar, DebruijnIndex};
use hir_def::{
db::DefDatabase,
@ -23,6 +24,15 @@ use crate::{
WhereClause,
};
pub(crate) fn fn_traits(db: &dyn DefDatabase, krate: CrateId) -> impl Iterator<Item = TraitId> {
let fn_traits = [
db.lang_item(krate, "fn".into()),
db.lang_item(krate, "fn_mut".into()),
db.lang_item(krate, "fn_once".into()),
];
array::IntoIter::new(fn_traits).into_iter().flatten().flat_map(|it| it.as_trait())
}
fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
let resolver = trait_.resolver(db);
// returning the iterator directly doesn't easily work because of