diff --git a/crates/hir_def/src/lib.rs b/crates/hir_def/src/lib.rs index 303083c6d2e..bb174aec810 100644 --- a/crates/hir_def/src/lib.rs +++ b/crates/hir_def/src/lib.rs @@ -112,6 +112,10 @@ impl ModuleId { self.def_map(db).containing_module(self.local_id) } + pub fn containing_block(&self) -> Option { + self.block + } + /// Returns `true` if this module represents a block expression. /// /// Returns `false` if this module is a submodule *inside* a block expression @@ -581,6 +585,18 @@ impl HasModule for GenericDefId { } } +impl HasModule for TypeAliasId { + fn module(&self, db: &dyn db::DefDatabase) -> ModuleId { + self.lookup(db).module(db) + } +} + +impl HasModule for TraitId { + fn module(&self, db: &dyn db::DefDatabase) -> ModuleId { + self.lookup(db).container + } +} + impl HasModule for StaticLoc { fn module(&self, _db: &dyn db::DefDatabase) -> ModuleId { self.container diff --git a/crates/hir_ty/src/chalk_db.rs b/crates/hir_ty/src/chalk_db.rs index 34c3f6bd912..a4c09c742aa 100644 --- a/crates/hir_ty/src/chalk_db.rs +++ b/crates/hir_ty/src/chalk_db.rs @@ -10,16 +10,16 @@ use chalk_solve::rust_ir::{self, OpaqueTyDatumBound, WellKnownTrait}; use base_db::CrateId; use hir_def::{ lang_item::{lang_attr, LangItemTarget}, - AssocContainerId, AssocItemId, GenericDefId, HasModule, Lookup, TypeAliasId, + AssocContainerId, AssocItemId, GenericDefId, HasModule, Lookup, ModuleId, TypeAliasId, }; use hir_expand::name::name; use crate::{ db::HirDatabase, display::HirDisplay, - from_assoc_type_id, from_chalk_trait_id, make_only_type_binders, + from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, make_only_type_binders, mapping::{from_chalk, ToChalk, TypeAliasAsValue}, - method_resolution::{TyFingerprint, ALL_FLOAT_FPS, ALL_INT_FPS}, + method_resolution::{TraitImpls, TyFingerprint, ALL_FLOAT_FPS, ALL_INT_FPS}, to_assoc_type_id, to_chalk_trait_id, traits::ChalkContext, utils::generics, @@ -105,12 +105,30 @@ impl<'a> chalk_solve::RustIrDatabase for ChalkContext<'a> { _ => self_ty_fp.as_ref().map(std::slice::from_ref).unwrap_or(&[]), }; + fn local_impls(db: &dyn HirDatabase, module: ModuleId) -> Option> { + db.trait_impls_in_block(module.containing_block()?) + } + // Note: Since we're using impls_for_trait, only impls where the trait - // can be resolved should ever reach Chalk. Symbol’s value as variable is void: impl_datum relies on that + // can be resolved should ever reach Chalk. impl_datum relies on that // and will panic if the trait can't be resolved. let in_deps = self.db.trait_impls_in_deps(self.krate); let in_self = self.db.trait_impls_in_crate(self.krate); - let impl_maps = [in_deps, in_self]; + let trait_module = trait_.module(self.db.upcast()); + let type_module = match self_ty_fp { + Some(TyFingerprint::Adt(adt_id)) => Some(adt_id.module(self.db.upcast())), + Some(TyFingerprint::ForeignType(type_id)) => { + Some(from_foreign_def_id(type_id).module(self.db.upcast())) + } + Some(TyFingerprint::Dyn(trait_id)) => Some(trait_id.module(self.db.upcast())), + _ => None, + }; + let impl_maps = [ + Some(in_deps), + Some(in_self), + local_impls(self.db, trait_module), + type_module.and_then(|m| local_impls(self.db, m)), + ]; let id_to_chalk = |id: hir_def::ImplId| id.to_chalk(self.db); @@ -118,14 +136,16 @@ impl<'a> chalk_solve::RustIrDatabase for ChalkContext<'a> { debug!("Unrestricted search for {:?} impls...", trait_); impl_maps .iter() - .flat_map(|crate_impl_defs| crate_impl_defs.for_trait(trait_).map(id_to_chalk)) + .filter_map(|o| o.as_ref()) + .flat_map(|impls| impls.for_trait(trait_).map(id_to_chalk)) .collect() } else { impl_maps .iter() - .flat_map(|crate_impl_defs| { + .filter_map(|o| o.as_ref()) + .flat_map(|impls| { fps.iter().flat_map(move |fp| { - crate_impl_defs.for_trait_and_self_ty(trait_, *fp).map(id_to_chalk) + impls.for_trait_and_self_ty(trait_, *fp).map(id_to_chalk) }) }) .collect() diff --git a/crates/hir_ty/src/db.rs b/crates/hir_ty/src/db.rs index be5b9110ed9..b9003c413bb 100644 --- a/crates/hir_ty/src/db.rs +++ b/crates/hir_ty/src/db.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use base_db::{impl_intern_key, salsa, CrateId, Upcast}; use hir_def::{ - db::DefDatabase, expr::ExprId, ConstParamId, DefWithBodyId, FunctionId, GenericDefId, ImplId, - LifetimeParamId, LocalFieldId, TypeParamId, VariantId, + db::DefDatabase, expr::ExprId, BlockId, ConstParamId, DefWithBodyId, FunctionId, GenericDefId, + ImplId, LifetimeParamId, LocalFieldId, TypeParamId, VariantId, }; use la_arena::ArenaMap; @@ -79,6 +79,9 @@ pub trait HirDatabase: DefDatabase + Upcast { #[salsa::invoke(TraitImpls::trait_impls_in_crate_query)] fn trait_impls_in_crate(&self, krate: CrateId) -> Arc; + #[salsa::invoke(TraitImpls::trait_impls_in_block_query)] + fn trait_impls_in_block(&self, krate: BlockId) -> Option>; + #[salsa::invoke(TraitImpls::trait_impls_in_deps_query)] fn trait_impls_in_deps(&self, krate: CrateId) -> Arc; diff --git a/crates/hir_ty/src/method_resolution.rs b/crates/hir_ty/src/method_resolution.rs index f3d390961a5..3d233b1e20a 100644 --- a/crates/hir_ty/src/method_resolution.rs +++ b/crates/hir_ty/src/method_resolution.rs @@ -8,7 +8,7 @@ use arrayvec::ArrayVec; use base_db::{CrateId, Edition}; use chalk_ir::{cast::Cast, Mutability, UniverseIndex}; use hir_def::{ - lang_item::LangItemTarget, nameres::DefMap, AssocContainerId, AssocItemId, FunctionId, + lang_item::LangItemTarget, nameres::DefMap, AssocContainerId, AssocItemId, BlockId, FunctionId, GenericDefId, HasModule, ImplId, Lookup, ModuleId, TraitId, }; use hir_expand::name::Name; @@ -139,35 +139,47 @@ impl TraitImpls { let mut impls = Self { map: FxHashMap::default() }; let crate_def_map = db.crate_def_map(krate); - collect_def_map(db, &crate_def_map, &mut impls); + impls.collect_def_map(db, &crate_def_map); return Arc::new(impls); + } - fn collect_def_map(db: &dyn HirDatabase, def_map: &DefMap, impls: &mut TraitImpls) { - for (_module_id, module_data) in def_map.modules() { - for impl_id in module_data.scope.impls() { - let target_trait = match db.impl_trait(impl_id) { - Some(tr) => tr.skip_binders().hir_trait_id(), - None => continue, - }; - let self_ty = db.impl_self_ty(impl_id); - let self_ty_fp = TyFingerprint::for_trait_impl(self_ty.skip_binders()); - impls - .map - .entry(target_trait) - .or_default() - .entry(self_ty_fp) - .or_default() - .push(impl_id); - } + pub(crate) fn trait_impls_in_block_query( + db: &dyn HirDatabase, + block: BlockId, + ) -> Option> { + let _p = profile::span("trait_impls_in_block_query"); + let mut impls = Self { map: FxHashMap::default() }; - // To better support custom derives, collect impls in all unnamed const items. - // const _: () = { ... }; - for konst in module_data.scope.unnamed_consts() { - let body = db.body(konst.into()); - for (_, block_def_map) in body.blocks(db.upcast()) { - collect_def_map(db, &block_def_map, impls); - } + let block_def_map = db.block_def_map(block)?; + impls.collect_def_map(db, &block_def_map); + + return Some(Arc::new(impls)); + } + + fn collect_def_map(&mut self, db: &dyn HirDatabase, def_map: &DefMap) { + for (_module_id, module_data) in def_map.modules() { + for impl_id in module_data.scope.impls() { + let target_trait = match db.impl_trait(impl_id) { + Some(tr) => tr.skip_binders().hir_trait_id(), + None => continue, + }; + let self_ty = db.impl_self_ty(impl_id); + let self_ty_fp = TyFingerprint::for_trait_impl(self_ty.skip_binders()); + self.map + .entry(target_trait) + .or_default() + .entry(self_ty_fp) + .or_default() + .push(impl_id); + } + + // To better support custom derives, collect impls in all unnamed const items. + // const _: () = { ... }; + for konst in module_data.scope.unnamed_consts() { + let body = db.body(konst.into()); + for (_, block_def_map) in body.blocks(db.upcast()) { + self.collect_def_map(db, &block_def_map); } } } diff --git a/crates/hir_ty/src/tests/traits.rs b/crates/hir_ty/src/tests/traits.rs index 588f0d1d41b..6bcede4c46a 100644 --- a/crates/hir_ty/src/tests/traits.rs +++ b/crates/hir_ty/src/tests/traits.rs @@ -3740,3 +3740,70 @@ mod future { "#, ); } + +#[test] +fn local_impl_1() { + check_types( + r#" +trait Trait { + fn foo(&self) -> T; +} + +fn test() { + struct S; + impl Trait for S { + fn foo(&self) { 0 } + } + + S.foo(); + // ^^^^^^^ u32 +} +"#, + ); +} + +#[test] +fn local_impl_2() { + check_types( + r#" +struct S; + +fn test() { + trait Trait { + fn foo(&self) -> T; + } + impl Trait for S { + fn foo(&self) { 0 } + } + + S.foo(); + // ^^^^^^^ u32 +} +"#, + ); +} + +#[test] +fn local_impl_3() { + check_types( + r#" +trait Trait { + fn foo(&self) -> T; +} + +fn test() { + struct S1; + { + struct S2; + + impl Trait for S2 { + fn foo(&self) { S1 } + } + + S2.foo(); + // ^^^^^^^^ S1 + } +} +"#, + ); +}