From 4e703a277220a3f52b4772f298535df232b6cb7c Mon Sep 17 00:00:00 2001
From: Santiago Pastorino <spastorino@gmail.com>
Date: Wed, 26 Oct 2022 16:38:32 -0300
Subject: [PATCH] Add associated_items_for_impl_trait_in_trait query

---
 compiler/rustc_hir/src/definitions.rs         |  6 ++-
 .../src/rmeta/decoder/cstore_impl.rs          |  2 +
 compiler/rustc_metadata/src/rmeta/encoder.rs  |  9 +++++
 compiler/rustc_metadata/src/rmeta/mod.rs      |  1 +
 compiler/rustc_middle/src/hir/map/mod.rs      |  7 ++++
 compiler/rustc_middle/src/query/mod.rs        |  7 ++++
 .../src/typeid/typeid_itanium_cxx_abi.rs      |  1 +
 compiler/rustc_symbol_mangling/src/v0.rs      |  1 +
 compiler/rustc_ty_utils/src/assoc.rs          | 38 ++++++++++++++++++-
 9 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/compiler/rustc_hir/src/definitions.rs b/compiler/rustc_hir/src/definitions.rs
index cd3c620cbb7..8ceb176491b 100644
--- a/compiler/rustc_hir/src/definitions.rs
+++ b/compiler/rustc_hir/src/definitions.rs
@@ -280,6 +280,8 @@ pub enum DefPathData {
     AnonConst,
     /// An `impl Trait` type node.
     ImplTrait,
+    /// `impl Trait` generated associated type node.
+    ImplTraitAssocTy,
 }
 
 impl Definitions {
@@ -403,7 +405,7 @@ impl DefPathData {
             TypeNs(name) | ValueNs(name) | MacroNs(name) | LifetimeNs(name) => Some(name),
 
             Impl | ForeignMod | CrateRoot | Use | GlobalAsm | ClosureExpr | Ctor | AnonConst
-            | ImplTrait => None,
+            | ImplTrait | ImplTraitAssocTy => None,
         }
     }
 
@@ -422,7 +424,7 @@ impl DefPathData {
             ClosureExpr => DefPathDataName::Anon { namespace: sym::closure },
             Ctor => DefPathDataName::Anon { namespace: sym::constructor },
             AnonConst => DefPathDataName::Anon { namespace: sym::constant },
-            ImplTrait => DefPathDataName::Anon { namespace: sym::opaque },
+            ImplTrait | ImplTraitAssocTy => DefPathDataName::Anon { namespace: sym::opaque },
         }
     }
 }
diff --git a/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs b/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
index 0bacf51e911..9ca8c1abc16 100644
--- a/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
+++ b/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
@@ -248,6 +248,8 @@ provide! { tcx, def_id, other, cdata,
             .process_decoded(tcx, || panic!("{def_id:?} does not have trait_impl_trait_tys")))
      }
 
+    associated_items_for_impl_trait_in_trait => { table_defaulted_array }
+
     visibility => { cdata.get_visibility(def_id.index) }
     adt_def => { cdata.get_adt_def(def_id.index, tcx) }
     adt_destructor => {
diff --git a/compiler/rustc_metadata/src/rmeta/encoder.rs b/compiler/rustc_metadata/src/rmeta/encoder.rs
index 16029822d03..f0dafe73c00 100644
--- a/compiler/rustc_metadata/src/rmeta/encoder.rs
+++ b/compiler/rustc_metadata/src/rmeta/encoder.rs
@@ -1129,6 +1129,11 @@ fn should_encode_trait_impl_trait_tys(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
     })
 }
 
+// Return `false` to avoid encoding impl trait in trait, while we don't use the query.
+fn should_encode_fn_impl_trait_in_trait<'tcx>(_tcx: TyCtxt<'tcx>, _def_id: DefId) -> bool {
+    false
+}
+
 impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
     fn encode_attrs(&mut self, def_id: LocalDefId) {
         let tcx = self.tcx;
@@ -1211,6 +1216,10 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
             {
                 record!(self.tables.trait_impl_trait_tys[def_id] <- table);
             }
+            if should_encode_fn_impl_trait_in_trait(tcx, def_id) {
+                let table = tcx.associated_items_for_impl_trait_in_trait(def_id);
+                record_defaulted_array!(self.tables.associated_items_for_impl_trait_in_trait[def_id] <- table);
+            }
         }
 
         let inherent_impls = tcx.with_stable_hashing_context(|hcx| {
diff --git a/compiler/rustc_metadata/src/rmeta/mod.rs b/compiler/rustc_metadata/src/rmeta/mod.rs
index 01691398ad0..a7ec2d790b7 100644
--- a/compiler/rustc_metadata/src/rmeta/mod.rs
+++ b/compiler/rustc_metadata/src/rmeta/mod.rs
@@ -354,6 +354,7 @@ define_tables! {
     explicit_item_bounds: Table<DefIndex, LazyArray<(ty::Predicate<'static>, Span)>>,
     inferred_outlives_of: Table<DefIndex, LazyArray<(ty::Clause<'static>, Span)>>,
     inherent_impls: Table<DefIndex, LazyArray<DefIndex>>,
+    associated_items_for_impl_trait_in_trait: Table<DefIndex, LazyArray<DefId>>,
 
 - optional:
     attributes: Table<DefIndex, LazyArray<ast::Attribute>>,
diff --git a/compiler/rustc_middle/src/hir/map/mod.rs b/compiler/rustc_middle/src/hir/map/mod.rs
index 2eafc356dc3..2df851a7857 100644
--- a/compiler/rustc_middle/src/hir/map/mod.rs
+++ b/compiler/rustc_middle/src/hir/map/mod.rs
@@ -849,6 +849,13 @@ impl<'hir> Map<'hir> {
         }
     }
 
+    pub fn get_fn_output(self, def_id: LocalDefId) -> Option<&'hir FnRetTy<'hir>> {
+        match self.tcx.hir_owner(OwnerId { def_id }) {
+            Some(Owner { node, .. }) => node.fn_decl().map(|fn_decl| &fn_decl.output),
+            _ => None,
+        }
+    }
+
     pub fn expect_variant(self, id: HirId) -> &'hir Variant<'hir> {
         match self.find(id) {
             Some(Node::Variant(variant)) => variant,
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 8b846f10603..b7858d9f86d 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -767,6 +767,13 @@ rustc_queries! {
         desc { |tcx| "comparing impl items against trait for `{}`", tcx.def_path_str(impl_id) }
     }
 
+    /// Given an `fn_def_id`, create and return the associated items for that function.
+    query associated_items_for_impl_trait_in_trait(fn_def_id: DefId) -> &'tcx [DefId] {
+        desc { |tcx| "creating associated items for impl trait in trait returned by `{}`", tcx.def_path_str(fn_def_id) }
+        cache_on_disk_if { fn_def_id.is_local() }
+        separate_provide_extern
+    }
+
     /// Given an `impl_id`, return the trait it implements.
     /// Return `None` if this is an inherent impl.
     query impl_trait_ref(impl_id: DefId) -> Option<ty::EarlyBinder<ty::TraitRef<'tcx>>> {
diff --git a/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs b/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
index 949ef04dd4d..f4004784626 100644
--- a/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
+++ b/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
@@ -396,6 +396,7 @@ fn encode_ty_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
             hir::definitions::DefPathData::CrateRoot
             | hir::definitions::DefPathData::Use
             | hir::definitions::DefPathData::GlobalAsm
+            | hir::definitions::DefPathData::ImplTraitAssocTy
             | hir::definitions::DefPathData::MacroNs(..)
             | hir::definitions::DefPathData::LifetimeNs(..) => {
                 bug!("encode_ty_name: unexpected `{:?}`", disambiguated_data.data);
diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs
index c58b6a24ab5..d53f0cea6b5 100644
--- a/compiler/rustc_symbol_mangling/src/v0.rs
+++ b/compiler/rustc_symbol_mangling/src/v0.rs
@@ -791,6 +791,7 @@ impl<'tcx> Printer<'tcx> for &mut SymbolMangler<'tcx> {
             | DefPathData::Use
             | DefPathData::GlobalAsm
             | DefPathData::Impl
+            | DefPathData::ImplTraitAssocTy
             | DefPathData::MacroNs(_)
             | DefPathData::LifetimeNs(_) => {
                 bug!("symbol_names: unexpected DefPathData: {:?}", disambiguated_data.data)
diff --git a/compiler/rustc_ty_utils/src/assoc.rs b/compiler/rustc_ty_utils/src/assoc.rs
index a6e0f13f698..a6c137d2f3e 100644
--- a/compiler/rustc_ty_utils/src/assoc.rs
+++ b/compiler/rustc_ty_utils/src/assoc.rs
@@ -1,13 +1,16 @@
 use rustc_data_structures::fx::FxHashMap;
 use rustc_hir as hir;
-use rustc_hir::def_id::DefId;
-use rustc_middle::ty::{self, TyCtxt};
+use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_hir::definitions::DefPathData;
+use rustc_hir::intravisit::{self, Visitor};
+use rustc_middle::ty::{self, DefIdTree, TyCtxt};
 
 pub fn provide(providers: &mut ty::query::Providers) {
     *providers = ty::query::Providers {
         associated_item,
         associated_item_def_ids,
         associated_items,
+        associated_items_for_impl_trait_in_trait,
         impl_item_implementor_ids,
         ..*providers
     };
@@ -112,3 +115,34 @@ fn associated_item_from_impl_item_ref(impl_item_ref: &hir::ImplItemRef) -> ty::A
         fn_has_self_parameter: has_self,
     }
 }
+
+fn associated_items_for_impl_trait_in_trait(tcx: TyCtxt<'_>, fn_def_id: DefId) -> &'_ [DefId] {
+    struct RPITVisitor {
+        rpits: Vec<LocalDefId>,
+    }
+
+    impl<'v> Visitor<'v> for RPITVisitor {
+        fn visit_ty(&mut self, ty: &'v hir::Ty<'v>) {
+            if let hir::TyKind::OpaqueDef(item_id, _, _) = ty.kind {
+                self.rpits.push(item_id.owner_id.def_id)
+            }
+            intravisit::walk_ty(self, ty)
+        }
+    }
+
+    let mut visitor = RPITVisitor { rpits: Vec::new() };
+
+    if let Some(output) = tcx.hir().get_fn_output(fn_def_id.expect_local()) {
+        visitor.visit_fn_ret_ty(output);
+
+        let trait_def_id = tcx.parent(fn_def_id).expect_local();
+
+        tcx.arena.alloc_from_iter(visitor.rpits.iter().map(|_opaque_ty_def_id| {
+            let trait_assoc_ty =
+                tcx.at(output.span()).create_def(trait_def_id, DefPathData::ImplTraitAssocTy);
+            trait_assoc_ty.def_id().to_def_id()
+        }))
+    } else {
+        &[]
+    }
+}