From 7ec72efe10df28fcf5c6ec13c2a487572041be59 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Thu, 2 Feb 2023 21:22:02 +0000
Subject: [PATCH] Allow the elaborator to only filter to real supertraits

---
 .../rustc_hir_analysis/src/astconv/mod.rs     | 66 ++++++++++---------
 compiler/rustc_hir_typeck/src/closure.rs      |  5 +-
 compiler/rustc_infer/src/traits/util.rs       | 41 +++++++-----
 compiler/rustc_lint/src/unused.rs             |  2 +
 .../src/solve/assembly/mod.rs                 |  5 +-
 .../alias/dont-elaborate-non-self.stderr      | 20 ++++++
 .../alias-where-clause-isnt-supertrait.stderr | 14 ++++
 7 files changed, 106 insertions(+), 47 deletions(-)
 create mode 100644 tests/ui/traits/alias/dont-elaborate-non-self.stderr
 create mode 100644 tests/ui/traits/trait-upcasting/alias-where-clause-isnt-supertrait.stderr

diff --git a/compiler/rustc_hir_analysis/src/astconv/mod.rs b/compiler/rustc_hir_analysis/src/astconv/mod.rs
index 37c894348cd..8cb95610da0 100644
--- a/compiler/rustc_hir_analysis/src/astconv/mod.rs
+++ b/compiler/rustc_hir_analysis/src/astconv/mod.rs
@@ -1663,39 +1663,45 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
             })
         });
 
-        let existential_projections = projection_bounds.iter().map(|(bound, _)| {
-            bound.map_bound(|mut b| {
-                assert_eq!(b.projection_ty.self_ty(), dummy_self);
+        let existential_projections = projection_bounds
+            .iter()
+            // We filter out traits that don't have `Self` as their self type above,
+            // we need to do the same for projections.
+            .filter(|(bound, _)| bound.skip_binder().self_ty() == dummy_self)
+            .map(|(bound, _)| {
+                bound.map_bound(|mut b| {
+                    assert_eq!(b.projection_ty.self_ty(), dummy_self);
 
-                // Like for trait refs, verify that `dummy_self` did not leak inside default type
-                // parameters.
-                let references_self = b.projection_ty.substs.iter().skip(1).any(|arg| {
-                    if arg.walk().any(|arg| arg == dummy_self.into()) {
-                        return true;
+                    // Like for trait refs, verify that `dummy_self` did not leak inside default type
+                    // parameters.
+                    let references_self = b.projection_ty.substs.iter().skip(1).any(|arg| {
+                        if arg.walk().any(|arg| arg == dummy_self.into()) {
+                            return true;
+                        }
+                        false
+                    });
+                    if references_self {
+                        let guar = tcx.sess.delay_span_bug(
+                            span,
+                            "trait object projection bounds reference `Self`",
+                        );
+                        let substs: Vec<_> = b
+                            .projection_ty
+                            .substs
+                            .iter()
+                            .map(|arg| {
+                                if arg.walk().any(|arg| arg == dummy_self.into()) {
+                                    return tcx.ty_error(guar).into();
+                                }
+                                arg
+                            })
+                            .collect();
+                        b.projection_ty.substs = tcx.mk_substs(&substs);
                     }
-                    false
-                });
-                if references_self {
-                    let guar = tcx
-                        .sess
-                        .delay_span_bug(span, "trait object projection bounds reference `Self`");
-                    let substs: Vec<_> = b
-                        .projection_ty
-                        .substs
-                        .iter()
-                        .map(|arg| {
-                            if arg.walk().any(|arg| arg == dummy_self.into()) {
-                                return tcx.ty_error(guar).into();
-                            }
-                            arg
-                        })
-                        .collect();
-                    b.projection_ty.substs = tcx.mk_substs(&substs);
-                }
 
-                ty::ExistentialProjection::erase_self_ty(tcx, b)
-            })
-        });
+                    ty::ExistentialProjection::erase_self_ty(tcx, b)
+                })
+            });
 
         let regular_trait_predicates = existential_trait_refs
             .map(|trait_ref| trait_ref.map_bound(ty::ExistentialPredicate::Trait));
diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs
index 15eec42d786..8c2495e1dd8 100644
--- a/compiler/rustc_hir_typeck/src/closure.rs
+++ b/compiler/rustc_hir_typeck/src/closure.rs
@@ -210,7 +210,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             // and we want to keep inference generally in the same order of
             // the registered obligations.
             predicates.rev(),
-        ) {
+        )
+        // We only care about self bounds
+        .filter_only_self()
+        {
             debug!(?pred);
             let bound_predicate = pred.kind();
 
diff --git a/compiler/rustc_infer/src/traits/util.rs b/compiler/rustc_infer/src/traits/util.rs
index 1f7c7652d94..ef01d5d513b 100644
--- a/compiler/rustc_infer/src/traits/util.rs
+++ b/compiler/rustc_infer/src/traits/util.rs
@@ -69,6 +69,7 @@ impl<'tcx> Extend<ty::Predicate<'tcx>> for PredicateSet<'tcx> {
 pub struct Elaborator<'tcx, O> {
     stack: Vec<O>,
     visited: PredicateSet<'tcx>,
+    only_self: bool,
 }
 
 /// Describes how to elaborate an obligation into a sub-obligation.
@@ -170,7 +171,8 @@ pub fn elaborate<'tcx, O: Elaboratable<'tcx>>(
     tcx: TyCtxt<'tcx>,
     obligations: impl IntoIterator<Item = O>,
 ) -> Elaborator<'tcx, O> {
-    let mut elaborator = Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx) };
+    let mut elaborator =
+        Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx), only_self: false };
     elaborator.extend_deduped(obligations);
     elaborator
 }
@@ -185,14 +187,25 @@ impl<'tcx, O: Elaboratable<'tcx>> Elaborator<'tcx, O> {
         self.stack.extend(obligations.into_iter().filter(|o| self.visited.insert(o.predicate())));
     }
 
+    /// Filter to only the supertraits of trait predicates, i.e. only the predicates
+    /// that have `Self` as their self type, instead of all implied predicates.
+    pub fn filter_only_self(mut self) -> Self {
+        self.only_self = true;
+        self
+    }
+
     fn elaborate(&mut self, elaboratable: &O) {
         let tcx = self.visited.tcx;
 
         let bound_predicate = elaboratable.predicate().kind();
         match bound_predicate.skip_binder() {
             ty::PredicateKind::Clause(ty::Clause::Trait(data)) => {
-                // Get predicates declared on the trait.
-                let predicates = tcx.implied_predicates_of(data.def_id());
+                // Get predicates implied by the trait, or only super predicates if we only care about self predicates.
+                let predicates = if self.only_self {
+                    tcx.super_predicates_of(data.def_id())
+                } else {
+                    tcx.implied_predicates_of(data.def_id())
+                };
 
                 let obligations =
                     predicates.predicates.iter().enumerate().map(|(index, &(mut pred, span))| {
@@ -350,18 +363,16 @@ pub fn supertraits<'tcx>(
     tcx: TyCtxt<'tcx>,
     trait_ref: ty::PolyTraitRef<'tcx>,
 ) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
-    let pred: ty::Predicate<'tcx> = trait_ref.to_predicate(tcx);
-    FilterToTraits::new(elaborate(tcx, [pred]))
+    elaborate(tcx, [trait_ref.to_predicate(tcx)]).filter_only_self().filter_to_traits()
 }
 
 pub fn transitive_bounds<'tcx>(
     tcx: TyCtxt<'tcx>,
     trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
 ) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
-    FilterToTraits::new(elaborate(
-        tcx,
-        trait_refs.map(|trait_ref| -> ty::Predicate<'tcx> { trait_ref.to_predicate(tcx) }),
-    ))
+    elaborate(tcx, trait_refs.map(|trait_ref| trait_ref.to_predicate(tcx)))
+        .filter_only_self()
+        .filter_to_traits()
 }
 
 /// A specialized variant of `elaborate` that only elaborates trait references that may
@@ -402,18 +413,18 @@ pub fn transitive_bounds_that_define_assoc_type<'tcx>(
 // Other
 ///////////////////////////////////////////////////////////////////////////
 
+impl<'tcx> Elaborator<'tcx, ty::Predicate<'tcx>> {
+    fn filter_to_traits(self) -> FilterToTraits<Self> {
+        FilterToTraits { base_iterator: self }
+    }
+}
+
 /// A filter around an iterator of predicates that makes it yield up
 /// just trait references.
 pub struct FilterToTraits<I> {
     base_iterator: I,
 }
 
-impl<I> FilterToTraits<I> {
-    fn new(base: I) -> FilterToTraits<I> {
-        FilterToTraits { base_iterator: base }
-    }
-}
-
 impl<'tcx, I: Iterator<Item = ty::Predicate<'tcx>>> Iterator for FilterToTraits<I> {
     type Item = ty::PolyTraitRef<'tcx>;
 
diff --git a/compiler/rustc_lint/src/unused.rs b/compiler/rustc_lint/src/unused.rs
index 35c461f5ace..1159d11e5c0 100644
--- a/compiler/rustc_lint/src/unused.rs
+++ b/compiler/rustc_lint/src/unused.rs
@@ -255,6 +255,8 @@ impl<'tcx> LateLintPass<'tcx> for UnusedResults {
                 ty::Adt(def, _) => is_def_must_use(cx, def.did(), span),
                 ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
                     elaborate(cx.tcx, cx.tcx.explicit_item_bounds(def).iter().cloned())
+                        // We only care about self bounds for the impl-trait
+                        .filter_only_self()
                         .find_map(|(pred, _span)| {
                             // We only look at the `DefId`, so it is safe to skip the binder here.
                             if let ty::PredicateKind::Clause(ty::Clause::Trait(
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs
index 12ee80b6722..3fb1d49b338 100644
--- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs
@@ -498,7 +498,10 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
         let tcx = self.tcx();
         let own_bounds: FxIndexSet<_> =
             bounds.iter().map(|bound| bound.with_self_ty(tcx, self_ty)).collect();
-        for assumption in elaborate(tcx, own_bounds.iter().copied()) {
+        for assumption in elaborate(tcx, own_bounds.iter().copied())
+            // we only care about bounds that match the `Self` type
+            .filter_only_self()
+        {
             // FIXME: Predicates are fully elaborated in the object type's existential bounds
             // list. We want to only consider these pre-elaborated projections, and not other
             // projection predicates that we reach by elaborating the principal trait ref,
diff --git a/tests/ui/traits/alias/dont-elaborate-non-self.stderr b/tests/ui/traits/alias/dont-elaborate-non-self.stderr
new file mode 100644
index 00000000000..247a4f81280
--- /dev/null
+++ b/tests/ui/traits/alias/dont-elaborate-non-self.stderr
@@ -0,0 +1,20 @@
+error[E0277]: the size for values of type `(dyn Fn() -> Fut + 'static)` cannot be known at compilation time
+  --> $DIR/dont-elaborate-non-self.rs:7:11
+   |
+LL | fn f<Fut>(a: dyn F<Fut>) {}
+   |           ^ doesn't have a size known at compile-time
+   |
+   = help: the trait `Sized` is not implemented for `(dyn Fn() -> Fut + 'static)`
+   = help: unsized fn params are gated as an unstable feature
+help: you can use `impl Trait` as the argument type
+   |
+LL | fn f<Fut>(a: impl F<Fut>) {}
+   |              ~~~~
+help: function arguments must have a statically known size, borrowed types always have a known size
+   |
+LL | fn f<Fut>(a: &dyn F<Fut>) {}
+   |              +
+
+error: aborting due to previous error
+
+For more information about this error, try `rustc --explain E0277`.
diff --git a/tests/ui/traits/trait-upcasting/alias-where-clause-isnt-supertrait.stderr b/tests/ui/traits/trait-upcasting/alias-where-clause-isnt-supertrait.stderr
new file mode 100644
index 00000000000..5574a032089
--- /dev/null
+++ b/tests/ui/traits/trait-upcasting/alias-where-clause-isnt-supertrait.stderr
@@ -0,0 +1,14 @@
+error[E0308]: mismatched types
+  --> $DIR/alias-where-clause-isnt-supertrait.rs:27:5
+   |
+LL | fn test(x: &dyn C) -> &dyn B {
+   |                       ------ expected `&dyn B` because of return type
+LL |     x
+   |     ^ expected trait `B`, found trait `C`
+   |
+   = note: expected reference `&dyn B`
+              found reference `&dyn C`
+
+error: aborting due to previous error
+
+For more information about this error, try `rustc --explain E0308`.