Auto merge of #13475 - lowr:fix/lookup-impl-method-trait-ref, r=flodiebold

fix: Test all generic args for trait when finding matching impl

Addresses https://github.com/rust-lang/rust-analyzer/pull/13463#issuecomment-1287816680

When finding matching impl for a trait method, we've been testing the unifiability of self type. However, there can be multiple impl of a trait for the same type with different generic arguments for the trait. This patch takes it into account and tests the unifiability of all type arguments for the trait (the first being the self type) thus enables rust-analyzer to find the correct impl even in such cases.
This commit is contained in:
bors 2022-10-26 12:06:26 +00:00
commit feefbe7918
4 changed files with 172 additions and 75 deletions

View File

@ -340,8 +340,8 @@ impl<'a> InferenceTable<'a> {
self.resolve_with_fallback(t, &|_, _, d, _| d)
}
/// Unify two types and register new trait goals that arise from that.
pub(crate) fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
/// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that.
pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool {
let result = match self.try_unify(ty1, ty2) {
Ok(r) => r,
Err(_) => return false,
@ -350,9 +350,13 @@ impl<'a> InferenceTable<'a> {
true
}
/// Unify two types and return new trait goals arising from it, so the
/// Unify two relatable values (e.g. `Ty`) and return new trait goals arising from it, so the
/// caller needs to deal with them.
pub(crate) fn try_unify<T: Zip<Interner>>(&mut self, t1: &T, t2: &T) -> InferResult<()> {
pub(crate) fn try_unify<T: ?Sized + Zip<Interner>>(
&mut self,
t1: &T,
t2: &T,
) -> InferResult<()> {
match self.var_unification_table.relate(
Interner,
&self.db,

View File

@ -22,10 +22,10 @@ use crate::{
from_foreign_def_id,
infer::{unify::InferenceTable, Adjust, Adjustment, AutoBorrow, OverloadedDeref, PointerCast},
primitive::{FloatTy, IntTy, UintTy},
static_lifetime,
static_lifetime, to_chalk_trait_id,
utils::all_super_traits,
AdtId, Canonical, CanonicalVarKinds, DebruijnIndex, ForeignDefId, InEnvironment, Interner,
Scalar, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
Scalar, Substitution, TraitEnvironment, TraitRef, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
};
/// This is used as a key for indexing impls.
@ -624,52 +624,76 @@ pub(crate) fn iterate_method_candidates<T>(
slot
}
/// Looks up the impl method that actually runs for the trait method `func`.
///
/// Returns `func` if it's not a method defined in a trait or the lookup failed.
pub fn lookup_impl_method(
self_ty: &Ty,
db: &dyn HirDatabase,
env: Arc<TraitEnvironment>,
trait_: TraitId,
func: FunctionId,
fn_subst: Substitution,
) -> FunctionId {
let trait_id = match func.lookup(db.upcast()).container {
ItemContainerId::TraitId(id) => id,
_ => return func,
};
let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
let fn_params = fn_subst.len(Interner) - trait_params;
let trait_ref = TraitRef {
trait_id: to_chalk_trait_id(trait_id),
substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)),
};
let name = &db.function_data(func).name;
lookup_impl_method_for_trait_ref(trait_ref, db, env, name).unwrap_or(func)
}
fn lookup_impl_method_for_trait_ref(
trait_ref: TraitRef,
db: &dyn HirDatabase,
env: Arc<TraitEnvironment>,
name: &Name,
) -> Option<FunctionId> {
let self_ty_fp = TyFingerprint::for_trait_impl(self_ty)?;
let trait_impls = db.trait_impls_in_deps(env.krate);
let impls = trait_impls.for_trait_and_self_ty(trait_, self_ty_fp);
let mut table = InferenceTable::new(db, env.clone());
find_matching_impl(impls, &mut table, &self_ty).and_then(|data| {
data.items.iter().find_map(|it| match it {
AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
_ => None,
})
let self_ty = trait_ref.self_type_parameter(Interner);
let self_ty_fp = TyFingerprint::for_trait_impl(&self_ty)?;
let impls = db.trait_impls_in_deps(env.krate);
let impls = impls.for_trait_and_self_ty(trait_ref.hir_trait_id(), self_ty_fp);
let table = InferenceTable::new(db, env);
let impl_data = find_matching_impl(impls, table, trait_ref)?;
impl_data.items.iter().find_map(|it| match it {
AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
_ => None,
})
}
fn find_matching_impl(
mut impls: impl Iterator<Item = ImplId>,
table: &mut InferenceTable<'_>,
self_ty: &Ty,
mut table: InferenceTable<'_>,
actual_trait_ref: TraitRef,
) -> Option<Arc<ImplData>> {
let db = table.db;
loop {
let impl_ = impls.next()?;
let r = table.run_in_snapshot(|table| {
let impl_data = db.impl_data(impl_);
let substs =
let impl_substs =
TyBuilder::subst_for_def(db, impl_, None).fill_with_inference_vars(table).build();
let impl_ty = db.impl_self_ty(impl_).substitute(Interner, &substs);
let trait_ref = db
.impl_trait(impl_)
.expect("non-trait method in find_matching_impl")
.substitute(Interner, &impl_substs);
table
.unify(self_ty, &impl_ty)
.then(|| {
let wh_goals =
crate::chalk_db::convert_where_clauses(db, impl_.into(), &substs)
.into_iter()
.map(|b| b.cast(Interner));
if !table.unify(&trait_ref, &actual_trait_ref) {
return None;
}
let goal = crate::Goal::all(Interner, wh_goals);
table.try_obligation(goal).map(|_| impl_data)
})
.flatten()
let wcs = crate::chalk_db::convert_where_clauses(db, impl_.into(), &impl_substs)
.into_iter()
.map(|b| b.cast(Interner));
let goal = crate::Goal::all(Interner, wcs);
table.try_obligation(goal).map(|_| impl_data)
});
if r.is_some() {
break r;
@ -1214,7 +1238,7 @@ fn is_valid_fn_candidate(
let expected_receiver =
sig.map(|s| s.params()[0].clone()).substitute(Interner, &fn_subst);
check_that!(table.unify(&receiver_ty, &expected_receiver));
check_that!(table.unify(receiver_ty, &expected_receiver));
}
if let ItemContainerId::ImplId(impl_id) = container {

View File

@ -270,7 +270,7 @@ impl SourceAnalyzer {
let expr_id = self.expr_id(db, &call.clone().into())?;
let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?;
Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, &substs))
Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, substs))
}
pub(crate) fn resolve_await_to_poll(
@ -311,7 +311,7 @@ impl SourceAnalyzer {
// HACK: subst for `poll()` coincides with that for `Future` because `poll()` itself
// doesn't have any generic parameters, so we skip building another subst for `poll()`.
let substs = hir_ty::TyBuilder::subst_for_def(db, future_trait, None).push(ty).build();
Some(self.resolve_impl_method_or_trait_def(db, poll_fn, &substs))
Some(self.resolve_impl_method_or_trait_def(db, poll_fn, substs))
}
pub(crate) fn resolve_prefix_expr(
@ -331,7 +331,7 @@ impl SourceAnalyzer {
// don't have any generic parameters, so we skip building another subst for the methods.
let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
}
pub(crate) fn resolve_index_expr(
@ -351,7 +351,7 @@ impl SourceAnalyzer {
.push(base_ty.clone())
.push(index_ty.clone())
.build();
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
}
pub(crate) fn resolve_bin_expr(
@ -372,7 +372,7 @@ impl SourceAnalyzer {
.push(rhs.clone())
.build();
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
}
pub(crate) fn resolve_try_expr(
@ -392,7 +392,7 @@ impl SourceAnalyzer {
// doesn't have any generic parameters, so we skip building another subst for `branch()`.
let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
}
pub(crate) fn resolve_field(
@ -487,9 +487,9 @@ impl SourceAnalyzer {
let mut prefer_value_ns = false;
let resolved = (|| {
let infer = self.infer.as_deref()?;
if let Some(path_expr) = parent().and_then(ast::PathExpr::cast) {
let expr_id = self.expr_id(db, &path_expr.into())?;
let infer = self.infer.as_ref()?;
if let Some(assoc) = infer.assoc_resolutions_for_expr(expr_id) {
let assoc = match assoc {
AssocItemId::FunctionId(f_in_trait) => {
@ -497,9 +497,12 @@ impl SourceAnalyzer {
None => assoc,
Some(func_ty) => {
if let TyKind::FnDef(_fn_def, subs) = func_ty.kind(Interner) {
self.resolve_impl_method(db, f_in_trait, subs)
.map(AssocItemId::FunctionId)
.unwrap_or(assoc)
self.resolve_impl_method_or_trait_def(
db,
f_in_trait,
subs.clone(),
)
.into()
} else {
assoc
}
@ -520,18 +523,18 @@ impl SourceAnalyzer {
prefer_value_ns = true;
} else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) {
let pat_id = self.pat_id(&path_pat.into())?;
if let Some(assoc) = self.infer.as_ref()?.assoc_resolutions_for_pat(pat_id) {
if let Some(assoc) = infer.assoc_resolutions_for_pat(pat_id) {
return Some(PathResolution::Def(AssocItem::from(assoc).into()));
}
if let Some(VariantId::EnumVariantId(variant)) =
self.infer.as_ref()?.variant_resolution_for_pat(pat_id)
infer.variant_resolution_for_pat(pat_id)
{
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
}
} else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) {
let expr_id = self.expr_id(db, &rec_lit.into())?;
if let Some(VariantId::EnumVariantId(variant)) =
self.infer.as_ref()?.variant_resolution_for_expr(expr_id)
infer.variant_resolution_for_expr(expr_id)
{
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
}
@ -541,8 +544,7 @@ impl SourceAnalyzer {
|| parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from);
if let Some(pat) = record_pat.or_else(tuple_struct_pat) {
let pat_id = self.pat_id(&pat)?;
let variant_res_for_pat =
self.infer.as_ref()?.variant_resolution_for_pat(pat_id);
let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id);
if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat {
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
}
@ -780,37 +782,22 @@ impl SourceAnalyzer {
false
}
fn resolve_impl_method(
&self,
db: &dyn HirDatabase,
func: FunctionId,
substs: &Substitution,
) -> Option<FunctionId> {
let impled_trait = match func.lookup(db.upcast()).container {
ItemContainerId::TraitId(trait_id) => trait_id,
_ => return None,
};
if substs.is_empty(Interner) {
return None;
}
let self_ty = substs.at(Interner, 0).ty(Interner)?;
let krate = self.resolver.krate();
let trait_env = self.resolver.body_owner()?.as_generic_def_id().map_or_else(
|| Arc::new(hir_ty::TraitEnvironment::empty(krate)),
|d| db.trait_environment(d),
);
let fun_data = db.function_data(func);
method_resolution::lookup_impl_method(self_ty, db, trait_env, impled_trait, &fun_data.name)
}
fn resolve_impl_method_or_trait_def(
&self,
db: &dyn HirDatabase,
func: FunctionId,
substs: &Substitution,
substs: Substitution,
) -> FunctionId {
self.resolve_impl_method(db, func, substs).unwrap_or(func)
let krate = self.resolver.krate();
let owner = match self.resolver.body_owner() {
Some(it) => it,
None => return func,
};
let env = owner.as_generic_def_id().map_or_else(
|| Arc::new(hir_ty::TraitEnvironment::empty(krate)),
|d| db.trait_environment(d),
);
method_resolution::lookup_impl_method(db, env, func, substs)
}
fn lang_trait_fn(

View File

@ -1834,4 +1834,86 @@ fn f() {
"#,
);
}
#[test]
fn goto_bin_op_multiple_impl() {
check(
r#"
//- minicore: add
struct S;
impl core::ops::Add for S {
fn add(
//^^^
) {}
}
impl core::ops::Add<usize> for S {
fn add(
) {}
}
fn f() {
S +$0 S
}
"#,
);
check(
r#"
//- minicore: add
struct S;
impl core::ops::Add for S {
fn add(
) {}
}
impl core::ops::Add<usize> for S {
fn add(
//^^^
) {}
}
fn f() {
S +$0 0usize
}
"#,
);
}
#[test]
fn path_call_multiple_trait_impl() {
check(
r#"
trait Trait<T> {
fn f(_: T);
}
impl Trait<i32> for usize {
fn f(_: i32) {}
//^
}
impl Trait<i64> for usize {
fn f(_: i64) {}
}
fn main() {
usize::f$0(0i32);
}
"#,
);
check(
r#"
trait Trait<T> {
fn f(_: T);
}
impl Trait<i32> for usize {
fn f(_: i32) {}
}
impl Trait<i64> for usize {
fn f(_: i64) {}
//^
}
fn main() {
usize::f$0(0i64);
}
"#,
)
}
}