From 3cb36317cdea5c4de0224f64a0ec6db5cd50a8fd Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 19 Feb 2024 20:36:17 +0000
Subject: [PATCH 1/7] Preserve variance on error in generalizer

---
 compiler/rustc_infer/src/infer/relate/generalize.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/compiler/rustc_infer/src/infer/relate/generalize.rs b/compiler/rustc_infer/src/infer/relate/generalize.rs
index e84d4ceaea8..b18c8a8b844 100644
--- a/compiler/rustc_infer/src/infer/relate/generalize.rs
+++ b/compiler/rustc_infer/src/infer/relate/generalize.rs
@@ -440,9 +440,9 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
         debug!(?self.ambient_variance, "new ambient variance");
         // Recursive calls to `relate` can overflow the stack. For example a deeper version of
         // `ui/associated-consts/issue-93775.rs`.
-        let r = ensure_sufficient_stack(|| self.relate(a, b))?;
+        let r = ensure_sufficient_stack(|| self.relate(a, b));
         self.ambient_variance = old_ambient_variance;
-        Ok(r)
+        r
     }
 
     #[instrument(level = "debug", skip(self, t2), ret)]

From c87b727a23cd1a04379770edbdc758538a8bc3d0 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Thu, 22 Feb 2024 17:18:35 +0000
Subject: [PATCH 2/7] Combine sub and eq

---
 .../src/type_check/relate_tys.rs              |  13 +-
 .../rustc_infer/src/infer/relate/combine.rs   |  30 +-
 .../rustc_infer/src/infer/relate/equate.rs    | 228 ------------
 compiler/rustc_infer/src/infer/relate/glb.rs  |  10 +-
 compiler/rustc_infer/src/infer/relate/lub.rs  |  10 +-
 compiler/rustc_infer/src/infer/relate/mod.rs  |   3 +-
 compiler/rustc_infer/src/infer/relate/sub.rs  | 229 ------------
 .../src/infer/relate/type_relating.rs         | 325 ++++++++++++++++++
 8 files changed, 356 insertions(+), 492 deletions(-)
 delete mode 100644 compiler/rustc_infer/src/infer/relate/equate.rs
 delete mode 100644 compiler/rustc_infer/src/infer/relate/sub.rs
 create mode 100644 compiler/rustc_infer/src/infer/relate/type_relating.rs

diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
index 85f63371659..fd6c29379a3 100644
--- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs
+++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
@@ -349,12 +349,15 @@ impl<'bccx, 'tcx> TypeRelation<'tcx> for NllTypeRelating<'_, 'bccx, 'tcx> {
 
         debug!(?self.ambient_variance);
         // In a bivariant context this always succeeds.
-        let r =
-            if self.ambient_variance == ty::Variance::Bivariant { a } else { self.relate(a, b)? };
+        let r = if self.ambient_variance == ty::Variance::Bivariant {
+            Ok(a)
+        } else {
+            self.relate(a, b)
+        };
 
         self.ambient_variance = old_ambient_variance;
 
-        Ok(r)
+        r
     }
 
     #[instrument(skip(self), level = "debug")]
@@ -579,10 +582,6 @@ impl<'bccx, 'tcx> ObligationEmittingRelation<'tcx> for NllTypeRelating<'_, 'bccx
         );
     }
 
-    fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
-        unreachable!("manually overridden to handle ty::Variance::Contravariant ambient variance")
-    }
-
     fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
         self.register_predicates([ty::Binder::dummy(match self.ambient_variance {
             ty::Variance::Covariant => ty::PredicateKind::AliasRelate(
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index 0852bb4f993..0a550660f94 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -1,4 +1,6 @@
-//! There are four type combiners: [Equate], [Sub], [Lub], and [Glb].
+//! There are four type combiners: [TypeRelating], [Lub], and [Glb],
+//! and `NllTypeRelating` in rustc_borrowck, which is only used for NLL.
+//!
 //! Each implements the trait [TypeRelation] and contains methods for
 //! combining two instances of various things and yielding a new instance.
 //! These combiner methods always yield a `Result<T>`. To relate two
@@ -22,10 +24,9 @@
 //! [TypeRelation::a_is_expected], so when dealing with contravariance
 //! this should be correctly updated.
 
-use super::equate::Equate;
 use super::glb::Glb;
 use super::lub::Lub;
-use super::sub::Sub;
+use super::type_relating::TypeRelating;
 use super::StructurallyRelateAliases;
 use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace};
 use crate::traits::{Obligation, PredicateObligations};
@@ -322,12 +323,12 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
         &'a mut self,
         structurally_relate_aliases: StructurallyRelateAliases,
         a_is_expected: bool,
-    ) -> Equate<'a, 'infcx, 'tcx> {
-        Equate::new(self, structurally_relate_aliases, a_is_expected)
+    ) -> TypeRelating<'a, 'infcx, 'tcx> {
+        TypeRelating::new(self, a_is_expected, structurally_relate_aliases, ty::Invariant)
     }
 
-    pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> Sub<'a, 'infcx, 'tcx> {
-        Sub::new(self, a_is_expected)
+    pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> {
+        TypeRelating::new(self, a_is_expected, StructurallyRelateAliases::No, ty::Covariant)
     }
 
     pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> {
@@ -367,19 +368,8 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
     /// be used if control over the obligation causes is required.
     fn register_predicates(&mut self, obligations: impl IntoIterator<Item: ToPredicate<'tcx>>);
 
-    /// Register an obligation that both types must be related to each other according to
-    /// the [`ty::AliasRelationDirection`] given by [`ObligationEmittingRelation::alias_relate_direction`]
-    fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
-        self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasRelate(
-            a.into(),
-            b.into(),
-            self.alias_relate_direction(),
-        ))]);
-    }
-
-    /// Relation direction emitted for `AliasRelate` predicates, corresponding to the direction
-    /// of the relation.
-    fn alias_relate_direction(&self) -> ty::AliasRelationDirection;
+    /// Register `AliasRelate` obligation(s) that both types must be related to each other.
+    fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>);
 }
 
 fn int_unification_error<'tcx>(
diff --git a/compiler/rustc_infer/src/infer/relate/equate.rs b/compiler/rustc_infer/src/infer/relate/equate.rs
deleted file mode 100644
index 1617a062ea0..00000000000
--- a/compiler/rustc_infer/src/infer/relate/equate.rs
+++ /dev/null
@@ -1,228 +0,0 @@
-use super::combine::{CombineFields, ObligationEmittingRelation};
-use super::StructurallyRelateAliases;
-use crate::infer::BoundRegionConversionTime::HigherRankedType;
-use crate::infer::{DefineOpaqueTypes, SubregionOrigin};
-use crate::traits::PredicateObligations;
-
-use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
-use rustc_middle::ty::GenericArgsRef;
-use rustc_middle::ty::TyVar;
-use rustc_middle::ty::{self, Ty, TyCtxt};
-
-use rustc_hir::def_id::DefId;
-use rustc_span::Span;
-
-/// Ensures `a` is made equal to `b`. Returns `a` on success.
-pub struct Equate<'combine, 'infcx, 'tcx> {
-    fields: &'combine mut CombineFields<'infcx, 'tcx>,
-    structurally_relate_aliases: StructurallyRelateAliases,
-    a_is_expected: bool,
-}
-
-impl<'combine, 'infcx, 'tcx> Equate<'combine, 'infcx, 'tcx> {
-    pub fn new(
-        fields: &'combine mut CombineFields<'infcx, 'tcx>,
-        structurally_relate_aliases: StructurallyRelateAliases,
-        a_is_expected: bool,
-    ) -> Equate<'combine, 'infcx, 'tcx> {
-        Equate { fields, structurally_relate_aliases, a_is_expected }
-    }
-}
-
-impl<'tcx> TypeRelation<'tcx> for Equate<'_, '_, 'tcx> {
-    fn tag(&self) -> &'static str {
-        "Equate"
-    }
-
-    fn tcx(&self) -> TyCtxt<'tcx> {
-        self.fields.tcx()
-    }
-
-    fn a_is_expected(&self) -> bool {
-        self.a_is_expected
-    }
-
-    fn relate_item_args(
-        &mut self,
-        _item_def_id: DefId,
-        a_arg: GenericArgsRef<'tcx>,
-        b_arg: GenericArgsRef<'tcx>,
-    ) -> RelateResult<'tcx, GenericArgsRef<'tcx>> {
-        // N.B., once we are equating types, we don't care about
-        // variance, so don't try to lookup the variance here. This
-        // also avoids some cycles (e.g., #41849) since looking up
-        // variance requires computing types which can require
-        // performing trait matching (which then performs equality
-        // unification).
-
-        relate::relate_args_invariantly(self, a_arg, b_arg)
-    }
-
-    fn relate_with_variance<T: Relate<'tcx>>(
-        &mut self,
-        _: ty::Variance,
-        _info: ty::VarianceDiagInfo<'tcx>,
-        a: T,
-        b: T,
-    ) -> RelateResult<'tcx, T> {
-        self.relate(a, b)
-    }
-
-    #[instrument(skip(self), level = "debug")]
-    fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
-        if a == b {
-            return Ok(a);
-        }
-
-        trace!(a = ?a.kind(), b = ?b.kind());
-
-        let infcx = self.fields.infcx;
-
-        let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
-        let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
-
-        match (a.kind(), b.kind()) {
-            (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
-                infcx.inner.borrow_mut().type_variables().equate(a_id, b_id);
-            }
-
-            (&ty::Infer(TyVar(a_vid)), _) => {
-                infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Invariant, b)?;
-            }
-
-            (_, &ty::Infer(TyVar(b_vid))) => {
-                infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Invariant, a)?;
-            }
-
-            (&ty::Error(e), _) | (_, &ty::Error(e)) => {
-                infcx.set_tainted_by_errors(e);
-                return Ok(Ty::new_error(self.tcx(), e));
-            }
-
-            (
-                &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }),
-                &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }),
-            ) if a_def_id == b_def_id => {
-                infcx.super_combine_tys(self, a, b)?;
-            }
-            (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _)
-            | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }))
-                if self.fields.define_opaque_types == DefineOpaqueTypes::Yes
-                    && def_id.is_local()
-                    && !self.fields.infcx.next_trait_solver() =>
-            {
-                self.fields.obligations.extend(
-                    infcx
-                        .handle_opaque_type(
-                            a,
-                            b,
-                            self.a_is_expected(),
-                            &self.fields.trace.cause,
-                            self.param_env(),
-                        )?
-                        .obligations,
-                );
-            }
-            _ => {
-                infcx.super_combine_tys(self, a, b)?;
-            }
-        }
-
-        Ok(a)
-    }
-
-    fn regions(
-        &mut self,
-        a: ty::Region<'tcx>,
-        b: ty::Region<'tcx>,
-    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
-        debug!("{}.regions({:?}, {:?})", self.tag(), a, b);
-        let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone()));
-        self.fields
-            .infcx
-            .inner
-            .borrow_mut()
-            .unwrap_region_constraints()
-            .make_eqregion(origin, a, b);
-        Ok(a)
-    }
-
-    fn consts(
-        &mut self,
-        a: ty::Const<'tcx>,
-        b: ty::Const<'tcx>,
-    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
-        self.fields.infcx.super_combine_consts(self, a, b)
-    }
-
-    fn binders<T>(
-        &mut self,
-        a: ty::Binder<'tcx, T>,
-        b: ty::Binder<'tcx, T>,
-    ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
-    where
-        T: Relate<'tcx>,
-    {
-        // A binder is equal to itself if it's structurally equal to itself
-        if a == b {
-            return Ok(a);
-        }
-
-        if let (Some(a), Some(b)) = (a.no_bound_vars(), b.no_bound_vars()) {
-            // Fast path for the common case.
-            self.relate(a, b)?;
-        } else {
-            // When equating binders, we check that there is a 1-to-1
-            // correspondence between the bound vars in both types.
-            //
-            // We do so by separately instantiating one of the binders with
-            // placeholders and the other with inference variables and then
-            // equating the instantiated types.
-            //
-            // We want `for<..> A == for<..> B` -- therefore we want
-            // `exists<..> A == for<..> B` and `exists<..> B == for<..> A`.
-
-            let span = self.fields.trace.cause.span;
-            let infcx = self.fields.infcx;
-
-            // Check if `exists<..> A == for<..> B`
-            infcx.enter_forall(b, |b| {
-                let a = infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, a);
-                self.relate(a, b)
-            })?;
-
-            // Check if `exists<..> B == for<..> A`.
-            infcx.enter_forall(a, |a| {
-                let b = infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, b);
-                self.relate(a, b)
-            })?;
-        }
-        Ok(a)
-    }
-}
-
-impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> {
-    fn span(&self) -> Span {
-        self.fields.trace.span()
-    }
-
-    fn structurally_relate_aliases(&self) -> StructurallyRelateAliases {
-        self.structurally_relate_aliases
-    }
-
-    fn param_env(&self) -> ty::ParamEnv<'tcx> {
-        self.fields.param_env
-    }
-
-    fn register_predicates(&mut self, obligations: impl IntoIterator<Item: ty::ToPredicate<'tcx>>) {
-        self.fields.register_predicates(obligations);
-    }
-
-    fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
-        self.fields.register_obligations(obligations);
-    }
-
-    fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
-        ty::AliasRelationDirection::Equate
-    }
-}
diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs
index 52a2f4c7347..9b77e6888b2 100644
--- a/compiler/rustc_infer/src/infer/relate/glb.rs
+++ b/compiler/rustc_infer/src/infer/relate/glb.rs
@@ -158,8 +158,12 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> {
         self.fields.register_obligations(obligations);
     }
 
-    fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
-        // FIXME(deferred_projection_equality): This isn't right, I think?
-        ty::AliasRelationDirection::Equate
+    fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
+        self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasRelate(
+            a.into(),
+            b.into(),
+            // FIXME(deferred_projection_equality): This isn't right, I think?
+            ty::AliasRelationDirection::Equate,
+        ))]);
     }
 }
diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs
index fa0da64ca65..db04e3231d6 100644
--- a/compiler/rustc_infer/src/infer/relate/lub.rs
+++ b/compiler/rustc_infer/src/infer/relate/lub.rs
@@ -158,8 +158,12 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> {
         self.fields.register_obligations(obligations)
     }
 
-    fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
-        // FIXME(deferred_projection_equality): This isn't right, I think?
-        ty::AliasRelationDirection::Equate
+    fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
+        self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasRelate(
+            a.into(),
+            b.into(),
+            // FIXME(deferred_projection_equality): This isn't right, I think?
+            ty::AliasRelationDirection::Equate,
+        ))]);
     }
 }
diff --git a/compiler/rustc_infer/src/infer/relate/mod.rs b/compiler/rustc_infer/src/infer/relate/mod.rs
index 8619cc502ad..86a01130167 100644
--- a/compiler/rustc_infer/src/infer/relate/mod.rs
+++ b/compiler/rustc_infer/src/infer/relate/mod.rs
@@ -2,13 +2,12 @@
 //! (except for some relations used for diagnostics and heuristics in the compiler).
 
 pub(super) mod combine;
-mod equate;
 mod generalize;
 mod glb;
 mod higher_ranked;
 mod lattice;
 mod lub;
-mod sub;
+mod type_relating;
 
 /// Whether aliases should be related structurally or not. Used
 /// to adjust the behavior of generalization and combine.
diff --git a/compiler/rustc_infer/src/infer/relate/sub.rs b/compiler/rustc_infer/src/infer/relate/sub.rs
deleted file mode 100644
index 2cc8d0d3b10..00000000000
--- a/compiler/rustc_infer/src/infer/relate/sub.rs
+++ /dev/null
@@ -1,229 +0,0 @@
-use super::combine::CombineFields;
-use super::StructurallyRelateAliases;
-use crate::infer::{DefineOpaqueTypes, ObligationEmittingRelation, SubregionOrigin};
-use crate::traits::{Obligation, PredicateObligations};
-
-use rustc_middle::ty::relate::{Cause, Relate, RelateResult, TypeRelation};
-use rustc_middle::ty::visit::TypeVisitableExt;
-use rustc_middle::ty::TyVar;
-use rustc_middle::ty::{self, Ty, TyCtxt};
-use rustc_span::Span;
-use std::mem;
-
-/// Ensures `a` is made a subtype of `b`. Returns `a` on success.
-pub struct Sub<'combine, 'a, 'tcx> {
-    fields: &'combine mut CombineFields<'a, 'tcx>,
-    a_is_expected: bool,
-}
-
-impl<'combine, 'infcx, 'tcx> Sub<'combine, 'infcx, 'tcx> {
-    pub fn new(
-        f: &'combine mut CombineFields<'infcx, 'tcx>,
-        a_is_expected: bool,
-    ) -> Sub<'combine, 'infcx, 'tcx> {
-        Sub { fields: f, a_is_expected }
-    }
-
-    fn with_expected_switched<R, F: FnOnce(&mut Self) -> R>(&mut self, f: F) -> R {
-        self.a_is_expected = !self.a_is_expected;
-        let result = f(self);
-        self.a_is_expected = !self.a_is_expected;
-        result
-    }
-}
-
-impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> {
-    fn tag(&self) -> &'static str {
-        "Sub"
-    }
-
-    fn tcx(&self) -> TyCtxt<'tcx> {
-        self.fields.infcx.tcx
-    }
-
-    fn a_is_expected(&self) -> bool {
-        self.a_is_expected
-    }
-
-    fn with_cause<F, R>(&mut self, cause: Cause, f: F) -> R
-    where
-        F: FnOnce(&mut Self) -> R,
-    {
-        debug!("sub with_cause={:?}", cause);
-        let old_cause = mem::replace(&mut self.fields.cause, Some(cause));
-        let r = f(self);
-        debug!("sub old_cause={:?}", old_cause);
-        self.fields.cause = old_cause;
-        r
-    }
-
-    fn relate_with_variance<T: Relate<'tcx>>(
-        &mut self,
-        variance: ty::Variance,
-        _info: ty::VarianceDiagInfo<'tcx>,
-        a: T,
-        b: T,
-    ) -> RelateResult<'tcx, T> {
-        match variance {
-            ty::Invariant => {
-                self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b)
-            }
-            ty::Covariant => self.relate(a, b),
-            ty::Bivariant => Ok(a),
-            ty::Contravariant => self.with_expected_switched(|this| this.relate(b, a)),
-        }
-    }
-
-    #[instrument(skip(self), level = "debug")]
-    fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
-        if a == b {
-            return Ok(a);
-        }
-
-        let infcx = self.fields.infcx;
-        let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
-        let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
-
-        match (a.kind(), b.kind()) {
-            (&ty::Infer(TyVar(_)), &ty::Infer(TyVar(_))) => {
-                // Shouldn't have any LBR here, so we can safely put
-                // this under a binder below without fear of accidental
-                // capture.
-                assert!(!a.has_escaping_bound_vars());
-                assert!(!b.has_escaping_bound_vars());
-
-                // can't make progress on `A <: B` if both A and B are
-                // type variables, so record an obligation.
-                self.fields.obligations.push(Obligation::new(
-                    self.tcx(),
-                    self.fields.trace.cause.clone(),
-                    self.fields.param_env,
-                    ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
-                        a_is_expected: self.a_is_expected,
-                        a,
-                        b,
-                    })),
-                ));
-
-                Ok(a)
-            }
-            (&ty::Infer(TyVar(a_vid)), _) => {
-                infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Covariant, b)?;
-                Ok(a)
-            }
-            (_, &ty::Infer(TyVar(b_vid))) => {
-                infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Contravariant, a)?;
-                Ok(a)
-            }
-
-            (&ty::Error(e), _) | (_, &ty::Error(e)) => {
-                infcx.set_tainted_by_errors(e);
-                Ok(Ty::new_error(self.tcx(), e))
-            }
-
-            (
-                &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }),
-                &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }),
-            ) if a_def_id == b_def_id => {
-                self.fields.infcx.super_combine_tys(self, a, b)?;
-                Ok(a)
-            }
-            (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _)
-            | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }))
-                if self.fields.define_opaque_types == DefineOpaqueTypes::Yes
-                    && def_id.is_local()
-                    && !self.fields.infcx.next_trait_solver() =>
-            {
-                self.fields.obligations.extend(
-                    infcx
-                        .handle_opaque_type(
-                            a,
-                            b,
-                            self.a_is_expected,
-                            &self.fields.trace.cause,
-                            self.param_env(),
-                        )?
-                        .obligations,
-                );
-                Ok(a)
-            }
-            _ => {
-                self.fields.infcx.super_combine_tys(self, a, b)?;
-                Ok(a)
-            }
-        }
-    }
-
-    fn regions(
-        &mut self,
-        a: ty::Region<'tcx>,
-        b: ty::Region<'tcx>,
-    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
-        debug!("{}.regions({:?}, {:?}) self.cause={:?}", self.tag(), a, b, self.fields.cause);
-
-        // FIXME -- we have more fine-grained information available
-        // from the "cause" field, we could perhaps give more tailored
-        // error messages.
-        let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone()));
-        // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a)
-        self.fields
-            .infcx
-            .inner
-            .borrow_mut()
-            .unwrap_region_constraints()
-            .make_subregion(origin, b, a);
-
-        Ok(a)
-    }
-
-    fn consts(
-        &mut self,
-        a: ty::Const<'tcx>,
-        b: ty::Const<'tcx>,
-    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
-        self.fields.infcx.super_combine_consts(self, a, b)
-    }
-
-    fn binders<T>(
-        &mut self,
-        a: ty::Binder<'tcx, T>,
-        b: ty::Binder<'tcx, T>,
-    ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
-    where
-        T: Relate<'tcx>,
-    {
-        // A binder is always a subtype of itself if it's structurally equal to itself
-        if a == b {
-            return Ok(a);
-        }
-
-        self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
-        Ok(a)
-    }
-}
-
-impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> {
-    fn span(&self) -> Span {
-        self.fields.trace.span()
-    }
-
-    fn structurally_relate_aliases(&self) -> StructurallyRelateAliases {
-        StructurallyRelateAliases::No
-    }
-
-    fn param_env(&self) -> ty::ParamEnv<'tcx> {
-        self.fields.param_env
-    }
-
-    fn register_predicates(&mut self, obligations: impl IntoIterator<Item: ty::ToPredicate<'tcx>>) {
-        self.fields.register_predicates(obligations);
-    }
-
-    fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
-        self.fields.register_obligations(obligations);
-    }
-
-    fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
-        ty::AliasRelationDirection::Subtype
-    }
-}
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
new file mode 100644
index 00000000000..c4de324e6ff
--- /dev/null
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -0,0 +1,325 @@
+use super::combine::CombineFields;
+use crate::infer::{
+    DefineOpaqueTypes, ObligationEmittingRelation, StructurallyRelateAliases, SubregionOrigin,
+};
+use crate::traits::{Obligation, PredicateObligations};
+
+use rustc_middle::ty::relate::{Cause, Relate, RelateResult, TypeRelation};
+use rustc_middle::ty::visit::TypeVisitableExt;
+use rustc_middle::ty::TyVar;
+use rustc_middle::ty::{self, Ty, TyCtxt};
+use rustc_span::Span;
+use std::mem;
+
+/// Enforce that `a` is equal to or a subtype of `b`.
+pub struct TypeRelating<'combine, 'a, 'tcx> {
+    fields: &'combine mut CombineFields<'a, 'tcx>,
+    a_is_expected: bool,
+    structurally_relate_aliases: StructurallyRelateAliases,
+    ambient_variance: ty::Variance,
+}
+
+impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
+    pub fn new(
+        f: &'combine mut CombineFields<'infcx, 'tcx>,
+        a_is_expected: bool,
+        structurally_relate_aliases: StructurallyRelateAliases,
+        ambient_variance: ty::Variance,
+    ) -> TypeRelating<'combine, 'infcx, 'tcx> {
+        TypeRelating { fields: f, a_is_expected, structurally_relate_aliases, ambient_variance }
+    }
+}
+
+impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
+    fn tag(&self) -> &'static str {
+        "TypeRelating"
+    }
+
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.fields.infcx.tcx
+    }
+
+    fn a_is_expected(&self) -> bool {
+        self.a_is_expected
+    }
+
+    fn with_cause<F, R>(&mut self, cause: Cause, f: F) -> R
+    where
+        F: FnOnce(&mut Self) -> R,
+    {
+        debug!("sub with_cause={:?}", cause);
+        let old_cause = mem::replace(&mut self.fields.cause, Some(cause));
+        let r = f(self);
+        debug!("sub old_cause={:?}", old_cause);
+        self.fields.cause = old_cause;
+        r
+    }
+
+    fn relate_with_variance<T: Relate<'tcx>>(
+        &mut self,
+        variance: ty::Variance,
+        _info: ty::VarianceDiagInfo<'tcx>,
+        a: T,
+        b: T,
+    ) -> RelateResult<'tcx, T> {
+        let old_ambient_variance = self.ambient_variance;
+        self.ambient_variance = self.ambient_variance.xform(variance);
+        debug!(?self.ambient_variance, "new ambient variance");
+
+        let r = if self.ambient_variance == ty::Bivariant { Ok(a) } else { self.relate(a, b) };
+
+        self.ambient_variance = old_ambient_variance;
+        r
+    }
+
+    #[instrument(skip(self), level = "debug")]
+    fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
+        if a == b {
+            return Ok(a);
+        }
+
+        let infcx = self.fields.infcx;
+        let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
+        let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
+
+        match (a.kind(), b.kind()) {
+            (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
+                // Shouldn't have any LBR here, so we can safely put
+                // this under a binder below without fear of accidental
+                // capture.
+                assert!(!a.has_escaping_bound_vars());
+                assert!(!b.has_escaping_bound_vars());
+
+                match self.ambient_variance {
+                    ty::Covariant => {
+                        // can't make progress on `A <: B` if both A and B are
+                        // type variables, so record an obligation.
+                        self.fields.obligations.push(Obligation::new(
+                            self.tcx(),
+                            self.fields.trace.cause.clone(),
+                            self.fields.param_env,
+                            ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
+                                a_is_expected: self.a_is_expected,
+                                a,
+                                b,
+                            })),
+                        ));
+                    }
+                    ty::Contravariant => {
+                        // can't make progress on `B <: A` if both A and B are
+                        // type variables, so record an obligation.
+                        self.fields.obligations.push(Obligation::new(
+                            self.tcx(),
+                            self.fields.trace.cause.clone(),
+                            self.fields.param_env,
+                            ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
+                                a_is_expected: !self.a_is_expected,
+                                a: b,
+                                b: a,
+                            })),
+                        ));
+                    }
+                    ty::Invariant => {
+                        infcx.inner.borrow_mut().type_variables().equate(a_id, b_id);
+                    }
+                    ty::Bivariant => {
+                        unreachable!("Expected bivariance to be handled in relate_with_variance")
+                    }
+                }
+            }
+
+            (&ty::Infer(TyVar(a_vid)), _) => {
+                infcx.instantiate_ty_var(
+                    self,
+                    self.a_is_expected,
+                    a_vid,
+                    self.ambient_variance,
+                    b,
+                )?;
+            }
+            (_, &ty::Infer(TyVar(b_vid))) => {
+                infcx.instantiate_ty_var(
+                    self,
+                    !self.a_is_expected,
+                    b_vid,
+                    self.ambient_variance.xform(ty::Contravariant),
+                    a,
+                )?;
+            }
+
+            (&ty::Error(e), _) | (_, &ty::Error(e)) => {
+                infcx.set_tainted_by_errors(e);
+                return Ok(Ty::new_error(self.tcx(), e));
+            }
+
+            (
+                &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }),
+                &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }),
+            ) if a_def_id == b_def_id => {
+                infcx.super_combine_tys(self, a, b)?;
+            }
+
+            (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _)
+            | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }))
+                if self.fields.define_opaque_types == DefineOpaqueTypes::Yes
+                    && def_id.is_local()
+                    && !infcx.next_trait_solver() =>
+            {
+                self.fields.obligations.extend(
+                    infcx
+                        .handle_opaque_type(
+                            a,
+                            b,
+                            self.a_is_expected,
+                            &self.fields.trace.cause,
+                            self.param_env(),
+                        )?
+                        .obligations,
+                );
+            }
+
+            _ => {
+                infcx.super_combine_tys(self, a, b)?;
+            }
+        }
+
+        Ok(a)
+    }
+
+    fn regions(
+        &mut self,
+        a: ty::Region<'tcx>,
+        b: ty::Region<'tcx>,
+    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
+        debug!("{}.regions({:?}, {:?}) self.cause={:?}", self.tag(), a, b, self.fields.cause);
+
+        // FIXME -- we have more fine-grained information available
+        // from the "cause" field, we could perhaps give more tailored
+        // error messages.
+        let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone()));
+
+        match self.ambient_variance {
+            // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a)
+            ty::Covariant => {
+                self.fields
+                    .infcx
+                    .inner
+                    .borrow_mut()
+                    .unwrap_region_constraints()
+                    .make_subregion(origin, b, a);
+            }
+            // Suptype(&'a u8, &'b u8) => Outlives('b: 'a) => SubRegion('a, 'b)
+            ty::Contravariant => {
+                self.fields
+                    .infcx
+                    .inner
+                    .borrow_mut()
+                    .unwrap_region_constraints()
+                    .make_subregion(origin, a, b);
+            }
+            ty::Invariant => {
+                // The order of `make_eqregion` apparently matters.
+                self.fields
+                    .infcx
+                    .inner
+                    .borrow_mut()
+                    .unwrap_region_constraints()
+                    .make_eqregion(origin, a, b);
+            }
+            ty::Bivariant => {
+                unreachable!("Expected bivariance to be handled in relate_with_variance")
+            }
+        }
+
+        Ok(a)
+    }
+
+    fn consts(
+        &mut self,
+        a: ty::Const<'tcx>,
+        b: ty::Const<'tcx>,
+    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
+        self.fields.infcx.super_combine_consts(self, a, b)
+    }
+
+    fn binders<T>(
+        &mut self,
+        a: ty::Binder<'tcx, T>,
+        b: ty::Binder<'tcx, T>,
+    ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
+    where
+        T: Relate<'tcx>,
+    {
+        if a == b {
+            // Do nothing
+        } else if let Some(a) = a.no_bound_vars()
+            && let Some(b) = b.no_bound_vars()
+        {
+            self.relate(a, b)?;
+        } else {
+            match self.ambient_variance {
+                ty::Covariant => {
+                    self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
+                }
+                ty::Contravariant => {
+                    self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?;
+                }
+                ty::Invariant => {
+                    self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
+                    self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?;
+                }
+                ty::Bivariant => {
+                    unreachable!("Expected bivariance to be handled in relate_with_variance")
+                }
+            }
+        }
+
+        Ok(a)
+    }
+}
+
+impl<'tcx> ObligationEmittingRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
+    fn span(&self) -> Span {
+        self.fields.trace.span()
+    }
+
+    fn param_env(&self) -> ty::ParamEnv<'tcx> {
+        self.fields.param_env
+    }
+
+    fn structurally_relate_aliases(&self) -> StructurallyRelateAliases {
+        self.structurally_relate_aliases
+    }
+
+    fn register_predicates(&mut self, obligations: impl IntoIterator<Item: ty::ToPredicate<'tcx>>) {
+        self.fields.register_predicates(obligations);
+    }
+
+    fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
+        self.fields.register_obligations(obligations);
+    }
+
+    fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
+        self.register_predicates([ty::Binder::dummy(match self.ambient_variance {
+            ty::Variance::Covariant => ty::PredicateKind::AliasRelate(
+                a.into(),
+                b.into(),
+                ty::AliasRelationDirection::Subtype,
+            ),
+            // a :> b is b <: a
+            ty::Variance::Contravariant => ty::PredicateKind::AliasRelate(
+                b.into(),
+                a.into(),
+                ty::AliasRelationDirection::Subtype,
+            ),
+            ty::Variance::Invariant => ty::PredicateKind::AliasRelate(
+                a.into(),
+                b.into(),
+                ty::AliasRelationDirection::Equate,
+            ),
+            ty::Variance::Bivariant => {
+                unreachable!("Expected bivariance to be handled in relate_with_variance")
+            }
+        })]);
+    }
+}

From 801dd1d061bc6db31547d45c64a32bbd9b4f6124 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 26 Feb 2024 15:58:53 +0000
Subject: [PATCH 3/7] Remove cause

---
 compiler/rustc_infer/src/infer/mod.rs         |  1 -
 .../rustc_infer/src/infer/relate/combine.rs   |  1 -
 .../src/infer/relate/type_relating.rs         | 24 ++-----------------
 compiler/rustc_middle/src/ty/relate.rs        | 17 +++++--------
 4 files changed, 8 insertions(+), 35 deletions(-)

diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index 6f52ded3551..b99ea35c22c 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -836,7 +836,6 @@ impl<'tcx> InferCtxt<'tcx> {
         CombineFields {
             infcx: self,
             trace,
-            cause: None,
             param_env,
             obligations: PredicateObligations::new(),
             define_opaque_types,
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index 0a550660f94..749c50b57b5 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -42,7 +42,6 @@ use rustc_span::Span;
 pub struct CombineFields<'infcx, 'tcx> {
     pub infcx: &'infcx InferCtxt<'tcx>,
     pub trace: TypeTrace<'tcx>,
-    pub cause: Option<ty::relate::Cause>,
     pub param_env: ty::ParamEnv<'tcx>,
     pub obligations: PredicateObligations<'tcx>,
     pub define_opaque_types: DefineOpaqueTypes,
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
index c4de324e6ff..ddc4bf9a514 100644
--- a/compiler/rustc_infer/src/infer/relate/type_relating.rs
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -4,12 +4,10 @@ use crate::infer::{
 };
 use crate::traits::{Obligation, PredicateObligations};
 
-use rustc_middle::ty::relate::{Cause, Relate, RelateResult, TypeRelation};
-use rustc_middle::ty::visit::TypeVisitableExt;
+use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
 use rustc_middle::ty::TyVar;
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_span::Span;
-use std::mem;
 
 /// Enforce that `a` is equal to or a subtype of `b`.
 pub struct TypeRelating<'combine, 'a, 'tcx> {
@@ -43,18 +41,6 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         self.a_is_expected
     }
 
-    fn with_cause<F, R>(&mut self, cause: Cause, f: F) -> R
-    where
-        F: FnOnce(&mut Self) -> R,
-    {
-        debug!("sub with_cause={:?}", cause);
-        let old_cause = mem::replace(&mut self.fields.cause, Some(cause));
-        let r = f(self);
-        debug!("sub old_cause={:?}", old_cause);
-        self.fields.cause = old_cause;
-        r
-    }
-
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
         variance: ty::Variance,
@@ -84,12 +70,6 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
 
         match (a.kind(), b.kind()) {
             (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
-                // Shouldn't have any LBR here, so we can safely put
-                // this under a binder below without fear of accidental
-                // capture.
-                assert!(!a.has_escaping_bound_vars());
-                assert!(!b.has_escaping_bound_vars());
-
                 match self.ambient_variance {
                     ty::Covariant => {
                         // can't make progress on `A <: B` if both A and B are
@@ -191,7 +171,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         a: ty::Region<'tcx>,
         b: ty::Region<'tcx>,
     ) -> RelateResult<'tcx, ty::Region<'tcx>> {
-        debug!("{}.regions({:?}, {:?}) self.cause={:?}", self.tag(), a, b, self.fields.cause);
+        debug!("{}.regions({:?}, {:?})", self.tag(), a, b);
 
         // FIXME -- we have more fine-grained information available
         // from the "cause" field, we could perhaps give more tailored
diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs
index 303f285b00c..abd39914cfd 100644
--- a/compiler/rustc_middle/src/ty/relate.rs
+++ b/compiler/rustc_middle/src/ty/relate.rs
@@ -30,13 +30,6 @@ pub trait TypeRelation<'tcx>: Sized {
     /// relation. Just affects error messages.
     fn a_is_expected(&self) -> bool;
 
-    fn with_cause<F, R>(&mut self, _cause: Cause, f: F) -> R
-    where
-        F: FnOnce(&mut Self) -> R,
-    {
-        f(self)
-    }
-
     /// Generic relation routine suitable for most anything.
     fn relate<T: Relate<'tcx>>(&mut self, a: T, b: T) -> RelateResult<'tcx, T> {
         Relate::relate(self, a, b)
@@ -452,10 +445,12 @@ pub fn structurally_relate_tys<'tcx, R: TypeRelation<'tcx>>(
         (&ty::Dynamic(a_obj, a_region, a_repr), &ty::Dynamic(b_obj, b_region, b_repr))
             if a_repr == b_repr =>
         {
-            let region_bound = relation.with_cause(Cause::ExistentialRegionBound, |relation| {
-                relation.relate(a_region, b_region)
-            })?;
-            Ok(Ty::new_dynamic(tcx, relation.relate(a_obj, b_obj)?, region_bound, a_repr))
+            Ok(Ty::new_dynamic(
+                tcx,
+                relation.relate(a_obj, b_obj)?,
+                relation.relate(a_region, b_region)?,
+                a_repr,
+            ))
         }
 
         (&ty::Coroutine(a_id, a_args), &ty::Coroutine(b_id, b_args)) if a_id == b_id => {

From 61daee66a89f52eb5fa6f103d5ac8dbcaa885709 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 26 Feb 2024 19:52:52 +0000
Subject: [PATCH 4/7] Get rid of some sub_exp and eq_exp

---
 compiler/rustc_borrowck/src/type_check/mod.rs |  1 -
 compiler/rustc_hir_typeck/src/coercion.rs     | 38 +++++++++----------
 compiler/rustc_infer/src/infer/mod.rs         |  6 ++-
 .../rustc_infer/src/infer/opaque_types.rs     | 12 ++----
 .../src/solve/eval_ctxt/mod.rs                |  1 -
 .../src/traits/engine.rs                      | 18 ---------
 .../error_reporting/type_err_ctxt_ext.rs      | 13 +++++--
 .../opaque-type-unsatisfied-bound.rs          |  6 +--
 .../opaque-type-unsatisfied-bound.stderr      |  6 +--
 .../opaque-type-unsatisfied-fn-bound.rs       |  2 +-
 .../opaque-type-unsatisfied-fn-bound.stderr   |  2 +-
 .../itiat-allow-nested-closures.bad.stderr    |  3 ++
 12 files changed, 48 insertions(+), 60 deletions(-)

diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index 75cc28bcab0..26e1e24d1a1 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -1066,7 +1066,6 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                             &cause,
                             param_env,
                             hidden_ty.ty,
-                            true,
                             &mut obligations,
                         )?;
 
diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index 179255993b4..792359c9dda 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -1493,6 +1493,21 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
             return;
         }
 
+        let (expected, found) = if label_expression_as_expected {
+            // In the case where this is a "forced unit", like
+            // `break`, we want to call the `()` "expected"
+            // since it is implied by the syntax.
+            // (Note: not all force-units work this way.)"
+            (expression_ty, self.merged_ty())
+        } else {
+            // Otherwise, the "expected" type for error
+            // reporting is the current unification type,
+            // which is basically the LUB of the expressions
+            // we've seen so far (combined with the expected
+            // type)
+            (self.merged_ty(), expression_ty)
+        };
+
         // Handle the actual type unification etc.
         let result = if let Some(expression) = expression {
             if self.pushed == 0 {
@@ -1540,12 +1555,11 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
             // Another example is `break` with no argument expression.
             assert!(expression_ty.is_unit(), "if let hack without unit type");
             fcx.at(cause, fcx.param_env)
-                // needed for tests/ui/type-alias-impl-trait/issue-65679-inst-opaque-ty-from-val-twice.rs
-                .eq_exp(
+                .eq(
+                    // needed for tests/ui/type-alias-impl-trait/issue-65679-inst-opaque-ty-from-val-twice.rs
                     DefineOpaqueTypes::Yes,
-                    label_expression_as_expected,
-                    expression_ty,
-                    self.merged_ty(),
+                    expected,
+                    found,
                 )
                 .map(|infer_ok| {
                     fcx.register_infer_ok_obligations(infer_ok);
@@ -1579,20 +1593,6 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
                 fcx.set_tainted_by_errors(
                     fcx.dcx().span_delayed_bug(cause.span, "coercion error but no error emitted"),
                 );
-                let (expected, found) = if label_expression_as_expected {
-                    // In the case where this is a "forced unit", like
-                    // `break`, we want to call the `()` "expected"
-                    // since it is implied by the syntax.
-                    // (Note: not all force-units work this way.)"
-                    (expression_ty, self.merged_ty())
-                } else {
-                    // Otherwise, the "expected" type for error
-                    // reporting is the current unification type,
-                    // which is basically the LUB of the expressions
-                    // we've seen so far (combined with the expected
-                    // type)
-                    (self.merged_ty(), expression_ty)
-                };
                 let (expected, found) = fcx.resolve_vars_if_possible((expected, found));
 
                 let mut err;
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index b99ea35c22c..73a25637e1a 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -1032,7 +1032,11 @@ impl<'tcx> InferCtxt<'tcx> {
         }
 
         self.enter_forall(predicate, |ty::SubtypePredicate { a_is_expected, a, b }| {
-            Ok(self.at(cause, param_env).sub_exp(DefineOpaqueTypes::No, a_is_expected, a, b))
+            if a_is_expected {
+                Ok(self.at(cause, param_env).sub(DefineOpaqueTypes::No, a, b))
+            } else {
+                Ok(self.at(cause, param_env).sup(DefineOpaqueTypes::No, b, a))
+            }
         })
     }
 
diff --git a/compiler/rustc_infer/src/infer/opaque_types.rs b/compiler/rustc_infer/src/infer/opaque_types.rs
index ec674407e52..d381c77ec66 100644
--- a/compiler/rustc_infer/src/infer/opaque_types.rs
+++ b/compiler/rustc_infer/src/infer/opaque_types.rs
@@ -102,7 +102,7 @@ impl<'tcx> InferCtxt<'tcx> {
             return Ok(InferOk { value: (), obligations: vec![] });
         }
         let (a, b) = if a_is_expected { (a, b) } else { (b, a) };
-        let process = |a: Ty<'tcx>, b: Ty<'tcx>, a_is_expected| match *a.kind() {
+        let process = |a: Ty<'tcx>, b: Ty<'tcx>| match *a.kind() {
             ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. }) if def_id.is_local() => {
                 let def_id = def_id.expect_local();
                 match self.defining_use_anchor {
@@ -169,14 +169,13 @@ impl<'tcx> InferCtxt<'tcx> {
                     cause.clone(),
                     param_env,
                     b,
-                    a_is_expected,
                 ))
             }
             _ => None,
         };
-        if let Some(res) = process(a, b, true) {
+        if let Some(res) = process(a, b) {
             res
-        } else if let Some(res) = process(b, a, false) {
+        } else if let Some(res) = process(b, a) {
             res
         } else {
             let (a, b) = self.resolve_vars_if_possible((a, b));
@@ -520,7 +519,6 @@ impl<'tcx> InferCtxt<'tcx> {
         cause: ObligationCause<'tcx>,
         param_env: ty::ParamEnv<'tcx>,
         hidden_ty: Ty<'tcx>,
-        a_is_expected: bool,
     ) -> InferResult<'tcx, ()> {
         let mut obligations = Vec::new();
 
@@ -529,7 +527,6 @@ impl<'tcx> InferCtxt<'tcx> {
             &cause,
             param_env,
             hidden_ty,
-            a_is_expected,
             &mut obligations,
         )?;
 
@@ -558,7 +555,6 @@ impl<'tcx> InferCtxt<'tcx> {
         cause: &ObligationCause<'tcx>,
         param_env: ty::ParamEnv<'tcx>,
         hidden_ty: Ty<'tcx>,
-        a_is_expected: bool,
         obligations: &mut Vec<PredicateObligation<'tcx>>,
     ) -> Result<(), TypeError<'tcx>> {
         // Ideally, we'd get the span where *this specific `ty` came
@@ -586,7 +582,7 @@ impl<'tcx> InferCtxt<'tcx> {
             if let Some(prev) = prev {
                 obligations.extend(
                     self.at(cause, param_env)
-                        .eq_exp(DefineOpaqueTypes::Yes, a_is_expected, prev, hidden_ty)?
+                        .eq(DefineOpaqueTypes::Yes, prev, hidden_ty)?
                         .obligations,
                 );
             }
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
index df4dcaff1e7..4a86f708632 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
@@ -904,7 +904,6 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
             &ObligationCause::dummy(),
             param_env,
             hidden_ty,
-            true,
             &mut obligations,
         )?;
         self.add_goals(GoalSource::Misc, obligations.into_iter().map(|o| o.into()));
diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs
index e789e9c2b6e..9fbec174ce8 100644
--- a/compiler/rustc_trait_selection/src/traits/engine.rs
+++ b/compiler/rustc_trait_selection/src/traits/engine.rs
@@ -116,24 +116,6 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
         self.infcx.at(cause, param_env).deeply_normalize(value, &mut **self.engine.borrow_mut())
     }
 
-    /// Makes `expected <: actual`.
-    pub fn eq_exp<T>(
-        &self,
-        cause: &ObligationCause<'tcx>,
-        param_env: ty::ParamEnv<'tcx>,
-        a_is_expected: bool,
-        a: T,
-        b: T,
-    ) -> Result<(), TypeError<'tcx>>
-    where
-        T: ToTrace<'tcx>,
-    {
-        self.infcx
-            .at(cause, param_env)
-            .eq_exp(DefineOpaqueTypes::Yes, a_is_expected, a, b)
-            .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
-    }
-
     pub fn eq<T: ToTrace<'tcx>>(
         &self,
         cause: &ObligationCause<'tcx>,
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
index 3275a4f3527..fa8edd11594 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
@@ -1528,6 +1528,12 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
                         | ObligationCauseCode::Coercion { .. }
                 );
 
+                let (expected, actual) = if is_normalized_term_expected {
+                    (normalized_term, data.term)
+                } else {
+                    (data.term, normalized_term)
+                };
+
                 // constrain inference variables a bit more to nested obligations from normalize so
                 // we can have more helpful errors.
                 //
@@ -1535,12 +1541,11 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
                 // since the normalization is just done to improve the error message.
                 let _ = ocx.select_where_possible();
 
-                if let Err(new_err) = ocx.eq_exp(
+                if let Err(new_err) = ocx.eq(
                     &obligation.cause,
                     obligation.param_env,
-                    is_normalized_term_expected,
-                    normalized_term,
-                    data.term,
+                    expected,
+                    actual,
                 ) {
                     (Some((data, is_normalized_term_expected, normalized_term, data.term)), new_err)
                 } else {
diff --git a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.rs b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.rs
index 05b167326d4..2607f047024 100644
--- a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.rs
+++ b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.rs
@@ -13,8 +13,8 @@ fn main() {
 }
 
 fn weird0() -> impl Sized + !Sized {}
-//~^ ERROR type mismatch resolving `() == impl !Sized + Sized`
+//~^ ERROR type mismatch resolving `impl !Sized + Sized == ()`
 fn weird1() -> impl !Sized + Sized {}
-//~^ ERROR type mismatch resolving `() == impl !Sized + Sized`
+//~^ ERROR type mismatch resolving `impl !Sized + Sized == ()`
 fn weird2() -> impl !Sized {}
-//~^ ERROR type mismatch resolving `() == impl !Sized`
+//~^ ERROR type mismatch resolving `impl !Sized == ()`
diff --git a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.stderr b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.stderr
index d803e56e817..ceaf42431fe 100644
--- a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.stderr
+++ b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-bound.stderr
@@ -1,16 +1,16 @@
-error[E0271]: type mismatch resolving `() == impl !Sized + Sized`
+error[E0271]: type mismatch resolving `impl !Sized + Sized == ()`
   --> $DIR/opaque-type-unsatisfied-bound.rs:15:16
    |
 LL | fn weird0() -> impl Sized + !Sized {}
    |                ^^^^^^^^^^^^^^^^^^^ types differ
 
-error[E0271]: type mismatch resolving `() == impl !Sized + Sized`
+error[E0271]: type mismatch resolving `impl !Sized + Sized == ()`
   --> $DIR/opaque-type-unsatisfied-bound.rs:17:16
    |
 LL | fn weird1() -> impl !Sized + Sized {}
    |                ^^^^^^^^^^^^^^^^^^^ types differ
 
-error[E0271]: type mismatch resolving `() == impl !Sized`
+error[E0271]: type mismatch resolving `impl !Sized == ()`
   --> $DIR/opaque-type-unsatisfied-bound.rs:19:16
    |
 LL | fn weird2() -> impl !Sized {}
diff --git a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.rs b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.rs
index d714d781c88..9951826a846 100644
--- a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.rs
+++ b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.rs
@@ -3,6 +3,6 @@
 #![feature(negative_bounds, unboxed_closures)]
 
 fn produce() -> impl !Fn<(u32,)> {}
-//~^ ERROR type mismatch resolving `() == impl !Fn<(u32,)>`
+//~^ ERROR type mismatch resolving `impl !Fn<(u32,)> == ()`
 
 fn main() {}
diff --git a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.stderr b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.stderr
index 1fd30410b00..e1b84e0df7a 100644
--- a/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.stderr
+++ b/tests/ui/traits/negative-bounds/opaque-type-unsatisfied-fn-bound.stderr
@@ -1,4 +1,4 @@
-error[E0271]: type mismatch resolving `() == impl !Fn<(u32,)>`
+error[E0271]: type mismatch resolving `impl !Fn<(u32,)> == ()`
   --> $DIR/opaque-type-unsatisfied-fn-bound.rs:5:17
    |
 LL | fn produce() -> impl !Fn<(u32,)> {}
diff --git a/tests/ui/type-alias-impl-trait/itiat-allow-nested-closures.bad.stderr b/tests/ui/type-alias-impl-trait/itiat-allow-nested-closures.bad.stderr
index 4acc47eaef2..9d38e8f36b1 100644
--- a/tests/ui/type-alias-impl-trait/itiat-allow-nested-closures.bad.stderr
+++ b/tests/ui/type-alias-impl-trait/itiat-allow-nested-closures.bad.stderr
@@ -8,6 +8,9 @@ LL |         let _: i32 = closure();
    |                ---   ^^^^^^^^^ expected `i32`, found opaque type
    |                |
    |                expected due to this
+   |
+   = note:     expected type `i32`
+           found opaque type `<() as Foo>::Assoc`
 
 error[E0308]: mismatched types
   --> $DIR/itiat-allow-nested-closures.rs:22:9

From 04e22627f5071b548c38bf0963d03f1115416aa9 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 26 Feb 2024 20:42:09 +0000
Subject: [PATCH 5/7] Remove a_is_expected from combine relations

---
 .../src/type_check/relate_tys.rs              |  1 -
 compiler/rustc_infer/src/infer/at.rs          | 93 ++++++-------------
 .../rustc_infer/src/infer/opaque_types.rs     |  8 +-
 .../rustc_infer/src/infer/relate/combine.rs   | 19 ++--
 compiler/rustc_infer/src/infer/relate/glb.rs  | 18 ++--
 .../src/infer/relate/higher_ranked.rs         |  8 +-
 compiler/rustc_infer/src/infer/relate/lub.rs  | 18 ++--
 .../src/infer/relate/type_relating.rs         | 36 +++----
 .../error_reporting/type_err_ctxt_ext.rs      |  9 +-
 9 files changed, 75 insertions(+), 135 deletions(-)

diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
index fd6c29379a3..5c5274d7d86 100644
--- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs
+++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
@@ -120,7 +120,6 @@ impl<'me, 'bccx, 'tcx> NllTypeRelating<'me, 'bccx, 'tcx> {
     fn relate_opaques(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
         let infcx = self.type_checker.infcx;
         debug_assert!(!infcx.next_trait_solver());
-        let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
         // `handle_opaque_type` cannot handle subtyping, so to support subtyping
         // we instead eagerly generalize here. This is a bit of a mess but will go
         // away once we're using the new solver.
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index a086d89b933..94088216c02 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -49,7 +49,6 @@ pub struct At<'a, 'tcx> {
 
 pub struct Trace<'a, 'tcx> {
     at: At<'a, 'tcx>,
-    a_is_expected: bool,
     trace: TypeTrace<'tcx>,
 }
 
@@ -105,23 +104,6 @@ pub trait ToTrace<'tcx>: Relate<'tcx> + Copy {
 }
 
 impl<'a, 'tcx> At<'a, 'tcx> {
-    /// Makes `a <: b`, where `a` may or may not be expected.
-    ///
-    /// See [`At::trace_exp`] and [`Trace::sub`] for a version of
-    /// this method that only requires `T: Relate<'tcx>`
-    pub fn sub_exp<T>(
-        self,
-        define_opaque_types: DefineOpaqueTypes,
-        a_is_expected: bool,
-        a: T,
-        b: T,
-    ) -> InferResult<'tcx, ()>
-    where
-        T: ToTrace<'tcx>,
-    {
-        self.trace_exp(a_is_expected, a, b).sub(define_opaque_types, a, b)
-    }
-
     /// Makes `actual <: expected`. For example, if type-checking a
     /// call like `foo(x)`, where `foo: fn(i32)`, you might have
     /// `sup(i32, x)`, since the "expected" type is the type that
@@ -138,7 +120,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
     where
         T: ToTrace<'tcx>,
     {
-        self.sub_exp(define_opaque_types, false, actual, expected)
+        self.trace(expected, actual).sup(define_opaque_types, expected, actual)
     }
 
     /// Makes `expected <: actual`.
@@ -154,24 +136,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
     where
         T: ToTrace<'tcx>,
     {
-        self.sub_exp(define_opaque_types, true, expected, actual)
-    }
-
-    /// Makes `expected <: actual`.
-    ///
-    /// See [`At::trace_exp`] and [`Trace::eq`] for a version of
-    /// this method that only requires `T: Relate<'tcx>`
-    pub fn eq_exp<T>(
-        self,
-        define_opaque_types: DefineOpaqueTypes,
-        a_is_expected: bool,
-        a: T,
-        b: T,
-    ) -> InferResult<'tcx, ()>
-    where
-        T: ToTrace<'tcx>,
-    {
-        self.trace_exp(a_is_expected, a, b).eq(define_opaque_types, a, b)
+        self.trace(expected, actual).sub(define_opaque_types, expected, actual)
     }
 
     /// Makes `expected <: actual`.
@@ -260,48 +225,50 @@ impl<'a, 'tcx> At<'a, 'tcx> {
     where
         T: ToTrace<'tcx>,
     {
-        self.trace_exp(true, expected, actual)
-    }
-
-    /// Like `trace`, but the expected value is determined by the
-    /// boolean argument (if true, then the first argument `a` is the
-    /// "expected" value).
-    pub fn trace_exp<T>(self, a_is_expected: bool, a: T, b: T) -> Trace<'a, 'tcx>
-    where
-        T: ToTrace<'tcx>,
-    {
-        let trace = ToTrace::to_trace(self.cause, a_is_expected, a, b);
-        Trace { at: self, trace, a_is_expected }
+        let trace = ToTrace::to_trace(self.cause, true, expected, actual);
+        Trace { at: self, trace }
     }
 }
 
 impl<'a, 'tcx> Trace<'a, 'tcx> {
-    /// Makes `a <: b` where `a` may or may not be expected (if
-    /// `a_is_expected` is true, then `a` is expected).
+    /// Makes `a <: b`.
     #[instrument(skip(self), level = "debug")]
     pub fn sub<T>(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()>
     where
         T: Relate<'tcx>,
     {
-        let Trace { at, trace, a_is_expected } = self;
+        let Trace { at, trace } = self;
         let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
         fields
-            .sub(a_is_expected)
+            .sub()
             .relate(a, b)
             .map(move |_| InferOk { value: (), obligations: fields.obligations })
     }
 
-    /// Makes `a == b`; the expectation is set by the call to
-    /// `trace()`.
+    /// Makes `a :> b`.
+    #[instrument(skip(self), level = "debug")]
+    pub fn sup<T>(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()>
+    where
+        T: Relate<'tcx>,
+    {
+        let Trace { at, trace } = self;
+        let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
+        fields
+            .sup()
+            .relate(a, b)
+            .map(move |_| InferOk { value: (), obligations: fields.obligations })
+    }
+
+    /// Makes `a == b`.
     #[instrument(skip(self), level = "debug")]
     pub fn eq<T>(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()>
     where
         T: Relate<'tcx>,
     {
-        let Trace { at, trace, a_is_expected } = self;
+        let Trace { at, trace } = self;
         let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
         fields
-            .equate(StructurallyRelateAliases::No, a_is_expected)
+            .equate(StructurallyRelateAliases::No)
             .relate(a, b)
             .map(move |_| InferOk { value: (), obligations: fields.obligations })
     }
@@ -313,11 +280,11 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
     where
         T: Relate<'tcx>,
     {
-        let Trace { at, trace, a_is_expected } = self;
+        let Trace { at, trace } = self;
         debug_assert!(at.infcx.next_trait_solver());
         let mut fields = at.infcx.combine_fields(trace, at.param_env, DefineOpaqueTypes::No);
         fields
-            .equate(StructurallyRelateAliases::Yes, a_is_expected)
+            .equate(StructurallyRelateAliases::Yes)
             .relate(a, b)
             .map(move |_| InferOk { value: (), obligations: fields.obligations })
     }
@@ -327,10 +294,10 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
     where
         T: Relate<'tcx>,
     {
-        let Trace { at, trace, a_is_expected } = self;
+        let Trace { at, trace } = self;
         let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
         fields
-            .lub(a_is_expected)
+            .lub()
             .relate(a, b)
             .map(move |t| InferOk { value: t, obligations: fields.obligations })
     }
@@ -340,10 +307,10 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
     where
         T: Relate<'tcx>,
     {
-        let Trace { at, trace, a_is_expected } = self;
+        let Trace { at, trace } = self;
         let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
         fields
-            .glb(a_is_expected)
+            .glb()
             .relate(a, b)
             .map(move |t| InferOk { value: t, obligations: fields.obligations })
     }
diff --git a/compiler/rustc_infer/src/infer/opaque_types.rs b/compiler/rustc_infer/src/infer/opaque_types.rs
index d381c77ec66..07245643ef5 100644
--- a/compiler/rustc_infer/src/infer/opaque_types.rs
+++ b/compiler/rustc_infer/src/infer/opaque_types.rs
@@ -522,13 +522,7 @@ impl<'tcx> InferCtxt<'tcx> {
     ) -> InferResult<'tcx, ()> {
         let mut obligations = Vec::new();
 
-        self.insert_hidden_type(
-            opaque_type_key,
-            &cause,
-            param_env,
-            hidden_ty,
-            &mut obligations,
-        )?;
+        self.insert_hidden_type(opaque_type_key, &cause, param_env, hidden_ty, &mut obligations)?;
 
         self.add_item_bounds_for_hidden_type(
             opaque_type_key.def_id.to_def_id(),
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index 749c50b57b5..099b7ff7c04 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -321,21 +321,24 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
     pub fn equate<'a>(
         &'a mut self,
         structurally_relate_aliases: StructurallyRelateAliases,
-        a_is_expected: bool,
     ) -> TypeRelating<'a, 'infcx, 'tcx> {
-        TypeRelating::new(self, a_is_expected, structurally_relate_aliases, ty::Invariant)
+        TypeRelating::new(self, structurally_relate_aliases, ty::Invariant)
     }
 
-    pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> {
-        TypeRelating::new(self, a_is_expected, StructurallyRelateAliases::No, ty::Covariant)
+    pub fn sub<'a>(&'a mut self) -> TypeRelating<'a, 'infcx, 'tcx> {
+        TypeRelating::new(self, StructurallyRelateAliases::No, ty::Covariant)
     }
 
-    pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> {
-        Lub::new(self, a_is_expected)
+    pub fn sup<'a>(&'a mut self) -> TypeRelating<'a, 'infcx, 'tcx> {
+        TypeRelating::new(self, StructurallyRelateAliases::No, ty::Contravariant)
     }
 
-    pub fn glb<'a>(&'a mut self, a_is_expected: bool) -> Glb<'a, 'infcx, 'tcx> {
-        Glb::new(self, a_is_expected)
+    pub fn lub<'a>(&'a mut self) -> Lub<'a, 'infcx, 'tcx> {
+        Lub::new(self)
+    }
+
+    pub fn glb<'a>(&'a mut self) -> Glb<'a, 'infcx, 'tcx> {
+        Glb::new(self)
     }
 
     pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs
index 9b77e6888b2..f6796861b12 100644
--- a/compiler/rustc_infer/src/infer/relate/glb.rs
+++ b/compiler/rustc_infer/src/infer/relate/glb.rs
@@ -13,15 +13,11 @@ use crate::traits::{ObligationCause, PredicateObligations};
 /// "Greatest lower bound" (common subtype)
 pub struct Glb<'combine, 'infcx, 'tcx> {
     fields: &'combine mut CombineFields<'infcx, 'tcx>,
-    a_is_expected: bool,
 }
 
 impl<'combine, 'infcx, 'tcx> Glb<'combine, 'infcx, 'tcx> {
-    pub fn new(
-        fields: &'combine mut CombineFields<'infcx, 'tcx>,
-        a_is_expected: bool,
-    ) -> Glb<'combine, 'infcx, 'tcx> {
-        Glb { fields, a_is_expected }
+    pub fn new(fields: &'combine mut CombineFields<'infcx, 'tcx>) -> Glb<'combine, 'infcx, 'tcx> {
+        Glb { fields }
     }
 }
 
@@ -35,7 +31,7 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> {
     }
 
     fn a_is_expected(&self) -> bool {
-        self.a_is_expected
+        true
     }
 
     fn relate_with_variance<T: Relate<'tcx>>(
@@ -46,13 +42,11 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> {
         b: T,
     ) -> RelateResult<'tcx, T> {
         match variance {
-            ty::Invariant => {
-                self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b)
-            }
+            ty::Invariant => self.fields.equate(StructurallyRelateAliases::No).relate(a, b),
             ty::Covariant => self.relate(a, b),
             // FIXME(#41044) -- not correct, need test
             ty::Bivariant => Ok(a),
-            ty::Contravariant => self.fields.lub(self.a_is_expected).relate(a, b),
+            ty::Contravariant => self.fields.lub().relate(a, b),
         }
     }
 
@@ -126,7 +120,7 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
     }
 
     fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
-        let mut sub = self.fields.sub(self.a_is_expected);
+        let mut sub = self.fields.sub();
         sub.relate(v, a)?;
         sub.relate(v, b)?;
         Ok(())
diff --git a/compiler/rustc_infer/src/infer/relate/higher_ranked.rs b/compiler/rustc_infer/src/infer/relate/higher_ranked.rs
index 90be80f67b4..c94cbb0db03 100644
--- a/compiler/rustc_infer/src/infer/relate/higher_ranked.rs
+++ b/compiler/rustc_infer/src/infer/relate/higher_ranked.rs
@@ -49,7 +49,13 @@ impl<'a, 'tcx> CombineFields<'a, 'tcx> {
             debug!("b_prime={:?}", sup_prime);
 
             // Compare types now that bound regions have been replaced.
-            let result = self.sub(sub_is_expected).relate(sub_prime, sup_prime);
+            // Reorder the inputs so that the expected is passed first.
+            let result = if sub_is_expected {
+                self.sub().relate(sub_prime, sup_prime)
+            } else {
+                self.sup().relate(sup_prime, sub_prime)
+            };
+
             if result.is_ok() {
                 debug!("OK result={result:?}");
             }
diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs
index db04e3231d6..3d9cfe7bf05 100644
--- a/compiler/rustc_infer/src/infer/relate/lub.rs
+++ b/compiler/rustc_infer/src/infer/relate/lub.rs
@@ -13,15 +13,11 @@ use rustc_span::Span;
 /// "Least upper bound" (common supertype)
 pub struct Lub<'combine, 'infcx, 'tcx> {
     fields: &'combine mut CombineFields<'infcx, 'tcx>,
-    a_is_expected: bool,
 }
 
 impl<'combine, 'infcx, 'tcx> Lub<'combine, 'infcx, 'tcx> {
-    pub fn new(
-        fields: &'combine mut CombineFields<'infcx, 'tcx>,
-        a_is_expected: bool,
-    ) -> Lub<'combine, 'infcx, 'tcx> {
-        Lub { fields, a_is_expected }
+    pub fn new(fields: &'combine mut CombineFields<'infcx, 'tcx>) -> Lub<'combine, 'infcx, 'tcx> {
+        Lub { fields }
     }
 }
 
@@ -35,7 +31,7 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
     }
 
     fn a_is_expected(&self) -> bool {
-        self.a_is_expected
+        true
     }
 
     fn relate_with_variance<T: Relate<'tcx>>(
@@ -46,13 +42,11 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
         b: T,
     ) -> RelateResult<'tcx, T> {
         match variance {
-            ty::Invariant => {
-                self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b)
-            }
+            ty::Invariant => self.fields.equate(StructurallyRelateAliases::No).relate(a, b),
             ty::Covariant => self.relate(a, b),
             // FIXME(#41044) -- not correct, need test
             ty::Bivariant => Ok(a),
-            ty::Contravariant => self.fields.glb(self.a_is_expected).relate(a, b),
+            ty::Contravariant => self.fields.glb().relate(a, b),
         }
     }
 
@@ -126,7 +120,7 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
     }
 
     fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
-        let mut sub = self.fields.sub(self.a_is_expected);
+        let mut sub = self.fields.sub();
         sub.relate(a, v)?;
         sub.relate(b, v)?;
         Ok(())
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
index ddc4bf9a514..7464b525724 100644
--- a/compiler/rustc_infer/src/infer/relate/type_relating.rs
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -12,7 +12,6 @@ use rustc_span::Span;
 /// Enforce that `a` is equal to or a subtype of `b`.
 pub struct TypeRelating<'combine, 'a, 'tcx> {
     fields: &'combine mut CombineFields<'a, 'tcx>,
-    a_is_expected: bool,
     structurally_relate_aliases: StructurallyRelateAliases,
     ambient_variance: ty::Variance,
 }
@@ -20,11 +19,10 @@ pub struct TypeRelating<'combine, 'a, 'tcx> {
 impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
     pub fn new(
         f: &'combine mut CombineFields<'infcx, 'tcx>,
-        a_is_expected: bool,
         structurally_relate_aliases: StructurallyRelateAliases,
         ambient_variance: ty::Variance,
     ) -> TypeRelating<'combine, 'infcx, 'tcx> {
-        TypeRelating { fields: f, a_is_expected, structurally_relate_aliases, ambient_variance }
+        TypeRelating { fields: f, structurally_relate_aliases, ambient_variance }
     }
 }
 
@@ -38,7 +36,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
     }
 
     fn a_is_expected(&self) -> bool {
-        self.a_is_expected
+        true
     }
 
     fn relate_with_variance<T: Relate<'tcx>>(
@@ -79,7 +77,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
                             self.fields.trace.cause.clone(),
                             self.fields.param_env,
                             ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
-                                a_is_expected: self.a_is_expected,
+                                a_is_expected: true,
                                 a,
                                 b,
                             })),
@@ -93,7 +91,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
                             self.fields.trace.cause.clone(),
                             self.fields.param_env,
                             ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
-                                a_is_expected: !self.a_is_expected,
+                                a_is_expected: false,
                                 a: b,
                                 b: a,
                             })),
@@ -109,18 +107,12 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
             }
 
             (&ty::Infer(TyVar(a_vid)), _) => {
-                infcx.instantiate_ty_var(
-                    self,
-                    self.a_is_expected,
-                    a_vid,
-                    self.ambient_variance,
-                    b,
-                )?;
+                infcx.instantiate_ty_var(self, true, a_vid, self.ambient_variance, b)?;
             }
             (_, &ty::Infer(TyVar(b_vid))) => {
                 infcx.instantiate_ty_var(
                     self,
-                    !self.a_is_expected,
+                    false,
                     b_vid,
                     self.ambient_variance.xform(ty::Contravariant),
                     a,
@@ -147,13 +139,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
             {
                 self.fields.obligations.extend(
                     infcx
-                        .handle_opaque_type(
-                            a,
-                            b,
-                            self.a_is_expected,
-                            &self.fields.trace.cause,
-                            self.param_env(),
-                        )?
+                        .handle_opaque_type(a, b, true, &self.fields.trace.cause, self.param_env())?
                         .obligations,
                 );
             }
@@ -239,14 +225,14 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         } else {
             match self.ambient_variance {
                 ty::Covariant => {
-                    self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
+                    self.fields.higher_ranked_sub(a, b, true)?;
                 }
                 ty::Contravariant => {
-                    self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?;
+                    self.fields.higher_ranked_sub(b, a, false)?;
                 }
                 ty::Invariant => {
-                    self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
-                    self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?;
+                    self.fields.higher_ranked_sub(a, b, true)?;
+                    self.fields.higher_ranked_sub(b, a, false)?;
                 }
                 ty::Bivariant => {
                     unreachable!("Expected bivariance to be handled in relate_with_variance")
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
index fa8edd11594..7f7bd867f63 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
@@ -1541,12 +1541,9 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
                 // since the normalization is just done to improve the error message.
                 let _ = ocx.select_where_possible();
 
-                if let Err(new_err) = ocx.eq(
-                    &obligation.cause,
-                    obligation.param_env,
-                    expected,
-                    actual,
-                ) {
+                if let Err(new_err) =
+                    ocx.eq(&obligation.cause, obligation.param_env, expected, actual)
+                {
                     (Some((data, is_normalized_term_expected, normalized_term, data.term)), new_err)
                 } else {
                     (None, error.err)

From b1536568db8890c2e9b7ecae30f26ab7d74219ff Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Wed, 28 Feb 2024 17:41:18 +0000
Subject: [PATCH 6/7] Fallout from removing a_is_expected

---
 .../src/type_check/relate_tys.rs              |  7 +-
 .../src/infer/error_reporting/mod.rs          |  4 -
 .../rustc_infer/src/infer/opaque_types.rs     |  6 +-
 .../src/infer/outlives/test_type_match.rs     |  4 -
 .../rustc_infer/src/infer/relate/combine.rs   | 31 +++-----
 .../src/infer/relate/generalize.rs            | 12 +--
 compiler/rustc_infer/src/infer/relate/glb.rs  |  4 -
 .../rustc_infer/src/infer/relate/lattice.rs   |  4 +-
 compiler/rustc_infer/src/infer/relate/lub.rs  |  4 -
 .../src/infer/relate/type_relating.rs         | 11 +--
 compiler/rustc_middle/src/ty/_match.rs        |  8 +-
 compiler/rustc_middle/src/ty/relate.rs        | 78 ++++++-------------
 12 files changed, 46 insertions(+), 127 deletions(-)

diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
index 5c5274d7d86..78609a482ed 100644
--- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs
+++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
@@ -160,8 +160,7 @@ impl<'me, 'bccx, 'tcx> NllTypeRelating<'me, 'bccx, 'tcx> {
             ),
         };
         let cause = ObligationCause::dummy_with_span(self.span());
-        let obligations =
-            infcx.handle_opaque_type(a, b, true, &cause, self.param_env())?.obligations;
+        let obligations = infcx.handle_opaque_type(a, b, &cause, self.param_env())?.obligations;
         self.register_obligations(obligations);
         Ok(())
     }
@@ -330,10 +329,6 @@ impl<'bccx, 'tcx> TypeRelation<'tcx> for NllTypeRelating<'_, 'bccx, 'tcx> {
         "nll::subtype"
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
     #[instrument(skip(self, info), level = "trace", ret)]
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
index 911b2f16c8b..1cf990fef04 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
@@ -2654,10 +2654,6 @@ impl<'tcx> TypeRelation<'tcx> for SameTypeModuloInfer<'_, 'tcx> {
         "SameTypeModuloInfer"
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
     fn relate_with_variance<T: relate::Relate<'tcx>>(
         &mut self,
         _variance: ty::Variance,
diff --git a/compiler/rustc_infer/src/infer/opaque_types.rs b/compiler/rustc_infer/src/infer/opaque_types.rs
index 07245643ef5..7a789a1b41b 100644
--- a/compiler/rustc_infer/src/infer/opaque_types.rs
+++ b/compiler/rustc_infer/src/infer/opaque_types.rs
@@ -78,9 +78,7 @@ impl<'tcx> InferCtxt<'tcx> {
                         span,
                     });
                     obligations.extend(
-                        self.handle_opaque_type(ty, ty_var, true, &cause, param_env)
-                            .unwrap()
-                            .obligations,
+                        self.handle_opaque_type(ty, ty_var, &cause, param_env).unwrap().obligations,
                     );
                     ty_var
                 }
@@ -94,14 +92,12 @@ impl<'tcx> InferCtxt<'tcx> {
         &self,
         a: Ty<'tcx>,
         b: Ty<'tcx>,
-        a_is_expected: bool,
         cause: &ObligationCause<'tcx>,
         param_env: ty::ParamEnv<'tcx>,
     ) -> InferResult<'tcx, ()> {
         if a.references_error() || b.references_error() {
             return Ok(InferOk { value: (), obligations: vec![] });
         }
-        let (a, b) = if a_is_expected { (a, b) } else { (b, a) };
         let process = |a: Ty<'tcx>, b: Ty<'tcx>| match *a.kind() {
             ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. }) if def_id.is_local() => {
                 let def_id = def_id.expect_local();
diff --git a/compiler/rustc_infer/src/infer/outlives/test_type_match.rs b/compiler/rustc_infer/src/infer/outlives/test_type_match.rs
index d547f51f381..29c11d4247d 100644
--- a/compiler/rustc_infer/src/infer/outlives/test_type_match.rs
+++ b/compiler/rustc_infer/src/infer/outlives/test_type_match.rs
@@ -144,10 +144,6 @@ impl<'tcx> TypeRelation<'tcx> for MatchAgainstHigherRankedOutlives<'tcx> {
         self.tcx
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    } // irrelevant
-
     #[instrument(level = "trace", skip(self))]
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index 099b7ff7c04..28b7db275a3 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -17,12 +17,6 @@
 //!
 //! On success, the  LUB/GLB operations return the appropriate bound. The
 //! return value of `Equate` or `Sub` shouldn't really be used.
-//!
-//! ## Contravariance
-//!
-//! We explicitly track which argument is expected using
-//! [TypeRelation::a_is_expected], so when dealing with contravariance
-//! this should be correctly updated.
 
 use super::glb::Glb;
 use super::lub::Lub;
@@ -57,7 +51,6 @@ impl<'tcx> InferCtxt<'tcx> {
     where
         R: ObligationEmittingRelation<'tcx>,
     {
-        let a_is_expected = relation.a_is_expected();
         debug_assert!(!a.has_escaping_bound_vars());
         debug_assert!(!b.has_escaping_bound_vars());
 
@@ -68,20 +61,20 @@ impl<'tcx> InferCtxt<'tcx> {
                     .borrow_mut()
                     .int_unification_table()
                     .unify_var_var(a_id, b_id)
-                    .map_err(|e| int_unification_error(a_is_expected, e))?;
+                    .map_err(|e| int_unification_error(true, e))?;
                 Ok(a)
             }
             (&ty::Infer(ty::IntVar(v_id)), &ty::Int(v)) => {
-                self.unify_integral_variable(a_is_expected, v_id, IntType(v))
+                self.unify_integral_variable(true, v_id, IntType(v))
             }
             (&ty::Int(v), &ty::Infer(ty::IntVar(v_id))) => {
-                self.unify_integral_variable(!a_is_expected, v_id, IntType(v))
+                self.unify_integral_variable(false, v_id, IntType(v))
             }
             (&ty::Infer(ty::IntVar(v_id)), &ty::Uint(v)) => {
-                self.unify_integral_variable(a_is_expected, v_id, UintType(v))
+                self.unify_integral_variable(true, v_id, UintType(v))
             }
             (&ty::Uint(v), &ty::Infer(ty::IntVar(v_id))) => {
-                self.unify_integral_variable(!a_is_expected, v_id, UintType(v))
+                self.unify_integral_variable(false, v_id, UintType(v))
             }
 
             // Relate floating-point variables to other types
@@ -90,14 +83,14 @@ impl<'tcx> InferCtxt<'tcx> {
                     .borrow_mut()
                     .float_unification_table()
                     .unify_var_var(a_id, b_id)
-                    .map_err(|e| float_unification_error(a_is_expected, e))?;
+                    .map_err(|e| float_unification_error(true, e))?;
                 Ok(a)
             }
             (&ty::Infer(ty::FloatVar(v_id)), &ty::Float(v)) => {
-                self.unify_float_variable(a_is_expected, v_id, v)
+                self.unify_float_variable(true, v_id, v)
             }
             (&ty::Float(v), &ty::Infer(ty::FloatVar(v_id))) => {
-                self.unify_float_variable(!a_is_expected, v_id, v)
+                self.unify_float_variable(false, v_id, v)
             }
 
             // We don't expect `TyVar` or `Fresh*` vars at this point with lazy norm.
@@ -130,7 +123,7 @@ impl<'tcx> InferCtxt<'tcx> {
 
             // All other cases of inference are errors
             (&ty::Infer(_), _) | (_, &ty::Infer(_)) => {
-                Err(TypeError::Sorts(ty::relate::expected_found(relation, a, b)))
+                Err(TypeError::Sorts(ty::relate::expected_found(a, b)))
             }
 
             // During coherence, opaque types should be treated as *possibly*
@@ -228,12 +221,12 @@ impl<'tcx> InferCtxt<'tcx> {
             }
 
             (ty::ConstKind::Infer(InferConst::Var(vid)), _) => {
-                self.instantiate_const_var(relation, relation.a_is_expected(), vid, b)?;
+                self.instantiate_const_var(relation, true, vid, b)?;
                 Ok(b)
             }
 
             (_, ty::ConstKind::Infer(InferConst::Var(vid))) => {
-                self.instantiate_const_var(relation, !relation.a_is_expected(), vid, a)?;
+                self.instantiate_const_var(relation, false, vid, a)?;
                 Ok(a)
             }
 
@@ -250,8 +243,6 @@ impl<'tcx> InferCtxt<'tcx> {
             {
                 match relation.structurally_relate_aliases() {
                     StructurallyRelateAliases::No => {
-                        let (a, b) = if relation.a_is_expected() { (a, b) } else { (b, a) };
-
                         relation.register_predicates([if self.next_trait_solver() {
                             ty::PredicateKind::AliasRelate(
                                 a.into(),
diff --git a/compiler/rustc_infer/src/infer/relate/generalize.rs b/compiler/rustc_infer/src/infer/relate/generalize.rs
index b18c8a8b844..5fb9d9341e0 100644
--- a/compiler/rustc_infer/src/infer/relate/generalize.rs
+++ b/compiler/rustc_infer/src/infer/relate/generalize.rs
@@ -130,7 +130,7 @@ impl<'tcx> InferCtxt<'tcx> {
             // instantiate_ty_var(?b, A) # expected and variance flipped
             // A rel A'
             // ```
-            if target_is_expected == relation.a_is_expected() {
+            if target_is_expected {
                 relation.relate(generalized_ty, source_ty)?;
             } else {
                 debug!("flip relation");
@@ -204,9 +204,9 @@ impl<'tcx> InferCtxt<'tcx> {
             .const_unification_table()
             .union_value(target_vid, ConstVariableValue::Known { value: generalized_ct });
 
-        // HACK: make sure that we `a_is_expected` continues to be
-        // correct when relating the generalized type with the source.
-        if target_is_expected == relation.a_is_expected() {
+        // Make sure that the order is correct when relating the
+        // generalized const and the source.
+        if target_is_expected {
             relation.relate_with_variance(
                 ty::Variance::Invariant,
                 ty::VarianceDiagInfo::default(),
@@ -398,10 +398,6 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
         "Generalizer"
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
     fn relate_item_args(
         &mut self,
         item_def_id: DefId,
diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs
index f6796861b12..b86d1b2671d 100644
--- a/compiler/rustc_infer/src/infer/relate/glb.rs
+++ b/compiler/rustc_infer/src/infer/relate/glb.rs
@@ -30,10 +30,6 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> {
         self.fields.tcx()
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
         variance: ty::Variance,
diff --git a/compiler/rustc_infer/src/infer/relate/lattice.rs b/compiler/rustc_infer/src/infer/relate/lattice.rs
index 744e2dfa380..747158585db 100644
--- a/compiler/rustc_infer/src/infer/relate/lattice.rs
+++ b/compiler/rustc_infer/src/infer/relate/lattice.rs
@@ -116,9 +116,7 @@ where
                 && !this.infcx().next_trait_solver() =>
         {
             this.register_obligations(
-                infcx
-                    .handle_opaque_type(a, b, this.a_is_expected(), this.cause(), this.param_env())?
-                    .obligations,
+                infcx.handle_opaque_type(a, b, this.cause(), this.param_env())?.obligations,
             );
             Ok(a)
         }
diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs
index 3d9cfe7bf05..20f5f65c984 100644
--- a/compiler/rustc_infer/src/infer/relate/lub.rs
+++ b/compiler/rustc_infer/src/infer/relate/lub.rs
@@ -30,10 +30,6 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
         self.fields.tcx()
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
         variance: ty::Variance,
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
index 7464b525724..c053adc5e07 100644
--- a/compiler/rustc_infer/src/infer/relate/type_relating.rs
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -35,10 +35,6 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         self.fields.infcx.tcx
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
         variance: ty::Variance,
@@ -139,7 +135,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
             {
                 self.fields.obligations.extend(
                     infcx
-                        .handle_opaque_type(a, b, true, &self.fields.trace.cause, self.param_env())?
+                        .handle_opaque_type(a, b, &self.fields.trace.cause, self.param_env())?
                         .obligations,
                 );
             }
@@ -158,10 +154,6 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         b: ty::Region<'tcx>,
     ) -> RelateResult<'tcx, ty::Region<'tcx>> {
         debug!("{}.regions({:?}, {:?})", self.tag(), a, b);
-
-        // FIXME -- we have more fine-grained information available
-        // from the "cause" field, we could perhaps give more tailored
-        // error messages.
         let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone()));
 
         match self.ambient_variance {
@@ -184,7 +176,6 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
                     .make_subregion(origin, a, b);
             }
             ty::Invariant => {
-                // The order of `make_eqregion` apparently matters.
                 self.fields
                     .infcx
                     .inner
diff --git a/compiler/rustc_middle/src/ty/_match.rs b/compiler/rustc_middle/src/ty/_match.rs
index 425a2dbd890..e28e4d66faf 100644
--- a/compiler/rustc_middle/src/ty/_match.rs
+++ b/compiler/rustc_middle/src/ty/_match.rs
@@ -37,10 +37,6 @@ impl<'tcx> TypeRelation<'tcx> for MatchAgainstFreshVars<'tcx> {
         self.tcx
     }
 
-    fn a_is_expected(&self) -> bool {
-        true
-    } // irrelevant
-
     fn relate_with_variance<T: Relate<'tcx>>(
         &mut self,
         _: ty::Variance,
@@ -75,7 +71,7 @@ impl<'tcx> TypeRelation<'tcx> for MatchAgainstFreshVars<'tcx> {
             ) => Ok(a),
 
             (&ty::Infer(_), _) | (_, &ty::Infer(_)) => {
-                Err(TypeError::Sorts(relate::expected_found(self, a, b)))
+                Err(TypeError::Sorts(relate::expected_found(a, b)))
             }
 
             (&ty::Error(guar), _) | (_, &ty::Error(guar)) => Ok(Ty::new_error(self.tcx(), guar)),
@@ -100,7 +96,7 @@ impl<'tcx> TypeRelation<'tcx> for MatchAgainstFreshVars<'tcx> {
             }
 
             (ty::ConstKind::Infer(_), _) | (_, ty::ConstKind::Infer(_)) => {
-                return Err(TypeError::ConstMismatch(relate::expected_found(self, a, b)));
+                return Err(TypeError::ConstMismatch(relate::expected_found(a, b)));
             }
 
             _ => {}
diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs
index abd39914cfd..990e78aff8a 100644
--- a/compiler/rustc_middle/src/ty/relate.rs
+++ b/compiler/rustc_middle/src/ty/relate.rs
@@ -15,21 +15,12 @@ use std::iter;
 
 pub type RelateResult<'tcx, T> = Result<T, TypeError<'tcx>>;
 
-#[derive(Clone, Debug)]
-pub enum Cause {
-    ExistentialRegionBound, // relating an existential region bound
-}
-
 pub trait TypeRelation<'tcx>: Sized {
     fn tcx(&self) -> TyCtxt<'tcx>;
 
     /// Returns a static string we can use for printouts.
     fn tag(&self) -> &'static str;
 
-    /// Returns `true` if the value `a` is the "expected" type in the
-    /// relation. Just affects error messages.
-    fn a_is_expected(&self) -> bool;
-
     /// Generic relation routine suitable for most anything.
     fn relate<T: Relate<'tcx>>(&mut self, a: T, b: T) -> RelateResult<'tcx, T> {
         Relate::relate(self, a, b)
@@ -171,11 +162,7 @@ impl<'tcx> Relate<'tcx> for ty::FnSig<'tcx> {
         let tcx = relation.tcx();
 
         if a.c_variadic != b.c_variadic {
-            return Err(TypeError::VariadicMismatch(expected_found(
-                relation,
-                a.c_variadic,
-                b.c_variadic,
-            )));
+            return Err(TypeError::VariadicMismatch(expected_found(a.c_variadic, b.c_variadic)));
         }
         let unsafety = relation.relate(a.unsafety, b.unsafety)?;
         let abi = relation.relate(a.abi, b.abi)?;
@@ -220,39 +207,31 @@ impl<'tcx> Relate<'tcx> for ty::FnSig<'tcx> {
 
 impl<'tcx> Relate<'tcx> for ty::BoundConstness {
     fn relate<R: TypeRelation<'tcx>>(
-        relation: &mut R,
+        _relation: &mut R,
         a: ty::BoundConstness,
         b: ty::BoundConstness,
     ) -> RelateResult<'tcx, ty::BoundConstness> {
-        if a != b {
-            Err(TypeError::ConstnessMismatch(expected_found(relation, a, b)))
-        } else {
-            Ok(a)
-        }
+        if a != b { Err(TypeError::ConstnessMismatch(expected_found(a, b))) } else { Ok(a) }
     }
 }
 
 impl<'tcx> Relate<'tcx> for hir::Unsafety {
     fn relate<R: TypeRelation<'tcx>>(
-        relation: &mut R,
+        _relation: &mut R,
         a: hir::Unsafety,
         b: hir::Unsafety,
     ) -> RelateResult<'tcx, hir::Unsafety> {
-        if a != b {
-            Err(TypeError::UnsafetyMismatch(expected_found(relation, a, b)))
-        } else {
-            Ok(a)
-        }
+        if a != b { Err(TypeError::UnsafetyMismatch(expected_found(a, b))) } else { Ok(a) }
     }
 }
 
 impl<'tcx> Relate<'tcx> for abi::Abi {
     fn relate<R: TypeRelation<'tcx>>(
-        relation: &mut R,
+        _relation: &mut R,
         a: abi::Abi,
         b: abi::Abi,
     ) -> RelateResult<'tcx, abi::Abi> {
-        if a == b { Ok(a) } else { Err(TypeError::AbiMismatch(expected_found(relation, a, b))) }
+        if a == b { Ok(a) } else { Err(TypeError::AbiMismatch(expected_found(a, b))) }
     }
 }
 
@@ -263,7 +242,7 @@ impl<'tcx> Relate<'tcx> for ty::AliasTy<'tcx> {
         b: ty::AliasTy<'tcx>,
     ) -> RelateResult<'tcx, ty::AliasTy<'tcx>> {
         if a.def_id != b.def_id {
-            Err(TypeError::ProjectionMismatched(expected_found(relation, a.def_id, b.def_id)))
+            Err(TypeError::ProjectionMismatched(expected_found(a.def_id, b.def_id)))
         } else {
             let args = match relation.tcx().def_kind(a.def_id) {
                 DefKind::OpaqueTy => relate_args_with_variances(
@@ -291,7 +270,7 @@ impl<'tcx> Relate<'tcx> for ty::ExistentialProjection<'tcx> {
         b: ty::ExistentialProjection<'tcx>,
     ) -> RelateResult<'tcx, ty::ExistentialProjection<'tcx>> {
         if a.def_id != b.def_id {
-            Err(TypeError::ProjectionMismatched(expected_found(relation, a.def_id, b.def_id)))
+            Err(TypeError::ProjectionMismatched(expected_found(a.def_id, b.def_id)))
         } else {
             let term = relation.relate_with_variance(
                 ty::Invariant,
@@ -318,7 +297,7 @@ impl<'tcx> Relate<'tcx> for ty::TraitRef<'tcx> {
     ) -> RelateResult<'tcx, ty::TraitRef<'tcx>> {
         // Different traits cannot be related.
         if a.def_id != b.def_id {
-            Err(TypeError::Traits(expected_found(relation, a.def_id, b.def_id)))
+            Err(TypeError::Traits(expected_found(a.def_id, b.def_id)))
         } else {
             let args = relate_args_invariantly(relation, a.args, b.args)?;
             Ok(ty::TraitRef::new(relation.tcx(), a.def_id, args))
@@ -334,7 +313,7 @@ impl<'tcx> Relate<'tcx> for ty::ExistentialTraitRef<'tcx> {
     ) -> RelateResult<'tcx, ty::ExistentialTraitRef<'tcx>> {
         // Different traits cannot be related.
         if a.def_id != b.def_id {
-            Err(TypeError::Traits(expected_found(relation, a.def_id, b.def_id)))
+            Err(TypeError::Traits(expected_found(a.def_id, b.def_id)))
         } else {
             let args = relate_args_invariantly(relation, a.args, b.args)?;
             Ok(ty::ExistentialTraitRef { def_id: a.def_id, args })
@@ -510,9 +489,9 @@ pub fn structurally_relate_tys<'tcx, R: TypeRelation<'tcx>>(
                     let sz_b = sz_b.try_to_target_usize(tcx);
 
                     match (sz_a, sz_b) {
-                        (Some(sz_a_val), Some(sz_b_val)) if sz_a_val != sz_b_val => Err(
-                            TypeError::FixedArraySize(expected_found(relation, sz_a_val, sz_b_val)),
-                        ),
+                        (Some(sz_a_val), Some(sz_b_val)) if sz_a_val != sz_b_val => {
+                            Err(TypeError::FixedArraySize(expected_found(sz_a_val, sz_b_val)))
+                        }
                         _ => Err(err),
                     }
                 }
@@ -531,9 +510,9 @@ pub fn structurally_relate_tys<'tcx, R: TypeRelation<'tcx>>(
                     iter::zip(as_, bs).map(|(a, b)| relation.relate(a, b)),
                 )?)
             } else if !(as_.is_empty() || bs.is_empty()) {
-                Err(TypeError::TupleSize(expected_found(relation, as_.len(), bs.len())))
+                Err(TypeError::TupleSize(expected_found(as_.len(), bs.len())))
             } else {
-                Err(TypeError::Sorts(expected_found(relation, a, b)))
+                Err(TypeError::Sorts(expected_found(a, b)))
             }
         }
 
@@ -554,7 +533,7 @@ pub fn structurally_relate_tys<'tcx, R: TypeRelation<'tcx>>(
             Ok(Ty::new_alias(tcx, a_kind, alias_ty))
         }
 
-        _ => Err(TypeError::Sorts(expected_found(relation, a, b))),
+        _ => Err(TypeError::Sorts(expected_found(a, b))),
     }
 }
 
@@ -652,13 +631,13 @@ pub fn structurally_relate_consts<'tcx, R: TypeRelation<'tcx>>(
                     let related_args = tcx.mk_const_list(&related_args);
                     Expr::FunctionCall(func, related_args)
                 }
-                _ => return Err(TypeError::ConstMismatch(expected_found(r, a, b))),
+                _ => return Err(TypeError::ConstMismatch(expected_found(a, b))),
             };
             return Ok(ty::Const::new_expr(tcx, expr, a.ty()));
         }
         _ => false,
     };
-    if is_match { Ok(a) } else { Err(TypeError::ConstMismatch(expected_found(relation, a, b))) }
+    if is_match { Ok(a) } else { Err(TypeError::ConstMismatch(expected_found(a, b))) }
 }
 
 impl<'tcx> Relate<'tcx> for &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
@@ -680,7 +659,7 @@ impl<'tcx> Relate<'tcx> for &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
         b_v.sort_by(|a, b| a.skip_binder().stable_cmp(tcx, &b.skip_binder()));
         b_v.dedup();
         if a_v.len() != b_v.len() {
-            return Err(TypeError::ExistentialMismatch(expected_found(relation, a, b)));
+            return Err(TypeError::ExistentialMismatch(expected_found(a, b)));
         }
 
         let v = iter::zip(a_v, b_v).map(|(ep_a, ep_b)| {
@@ -692,7 +671,7 @@ impl<'tcx> Relate<'tcx> for &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
                     relation.relate(ep_a.rebind(a), ep_b.rebind(b))?.skip_binder(),
                 ))),
                 (AutoTrait(a), AutoTrait(b)) if a == b => Ok(ep_a.rebind(AutoTrait(a))),
-                _ => Err(TypeError::ExistentialMismatch(expected_found(relation, a, b))),
+                _ => Err(TypeError::ExistentialMismatch(expected_found(a, b))),
             }
         });
         tcx.mk_poly_existential_predicates_from_iter(v)
@@ -792,15 +771,11 @@ impl<'tcx> Relate<'tcx> for GenericArg<'tcx> {
 
 impl<'tcx> Relate<'tcx> for ty::ImplPolarity {
     fn relate<R: TypeRelation<'tcx>>(
-        relation: &mut R,
+        _relation: &mut R,
         a: ty::ImplPolarity,
         b: ty::ImplPolarity,
     ) -> RelateResult<'tcx, ty::ImplPolarity> {
-        if a != b {
-            Err(TypeError::PolarityMismatch(expected_found(relation, a, b)))
-        } else {
-            Ok(a)
-        }
+        if a != b { Err(TypeError::PolarityMismatch(expected_found(a, b))) } else { Ok(a) }
     }
 }
 
@@ -834,9 +809,6 @@ impl<'tcx> Relate<'tcx> for Term<'tcx> {
 ///////////////////////////////////////////////////////////////////////////
 // Error handling
 
-pub fn expected_found<'tcx, R, T>(relation: &mut R, a: T, b: T) -> ExpectedFound<T>
-where
-    R: TypeRelation<'tcx>,
-{
-    ExpectedFound::new(relation.a_is_expected(), a, b)
+pub fn expected_found<T>(a: T, b: T) -> ExpectedFound<T> {
+    ExpectedFound::new(true, a, b)
 }

From 5072b659ffd502a099535342698a39f4acf8a32e Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Fri, 1 Mar 2024 01:20:11 +0000
Subject: [PATCH 7/7] Rebase fallout from TypeRelating::binders, inline
 higher_ranked_sub

---
 .../src/infer/relate/higher_ranked.rs         | 65 +------------------
 .../src/infer/relate/type_relating.rs         | 52 +++++++++++++--
 2 files changed, 51 insertions(+), 66 deletions(-)

diff --git a/compiler/rustc_infer/src/infer/relate/higher_ranked.rs b/compiler/rustc_infer/src/infer/relate/higher_ranked.rs
index c94cbb0db03..f30e366c198 100644
--- a/compiler/rustc_infer/src/infer/relate/higher_ranked.rs
+++ b/compiler/rustc_infer/src/infer/relate/higher_ranked.rs
@@ -1,70 +1,11 @@
 //! Helper routines for higher-ranked things. See the `doc` module at
 //! the end of the file for details.
 
-use super::combine::CombineFields;
 use crate::infer::CombinedSnapshot;
-use crate::infer::{HigherRankedType, InferCtxt};
+use crate::infer::InferCtxt;
 use rustc_middle::ty::fold::FnMutDelegate;
-use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
-use rustc_middle::ty::{self, Binder, Ty, TyCtxt, TypeFoldable};
-
-impl<'a, 'tcx> CombineFields<'a, 'tcx> {
-    /// Checks whether `for<..> sub <: for<..> sup` holds.
-    ///
-    /// For this to hold, **all** instantiations of the super type
-    /// have to be a super type of **at least one** instantiation of
-    /// the subtype.
-    ///
-    /// This is implemented by first entering a new universe.
-    /// We then replace all bound variables in `sup` with placeholders,
-    /// and all bound variables in `sub` with inference vars.
-    /// We can then just relate the two resulting types as normal.
-    ///
-    /// Note: this is a subtle algorithm. For a full explanation, please see
-    /// the [rustc dev guide][rd]
-    ///
-    /// [rd]: https://rustc-dev-guide.rust-lang.org/borrow_check/region_inference/placeholders_and_universes.html
-    #[instrument(skip(self), level = "debug")]
-    pub fn higher_ranked_sub<T>(
-        &mut self,
-        sub: Binder<'tcx, T>,
-        sup: Binder<'tcx, T>,
-        sub_is_expected: bool,
-    ) -> RelateResult<'tcx, ()>
-    where
-        T: Relate<'tcx>,
-    {
-        let span = self.trace.cause.span;
-        // First, we instantiate each bound region in the supertype with a
-        // fresh placeholder region. Note that this automatically creates
-        // a new universe if needed.
-        self.infcx.enter_forall(sup, |sup_prime| {
-            // Next, we instantiate each bound region in the subtype
-            // with a fresh region variable. These region variables --
-            // but no other preexisting region variables -- can name
-            // the placeholders.
-            let sub_prime =
-                self.infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, sub);
-            debug!("a_prime={:?}", sub_prime);
-            debug!("b_prime={:?}", sup_prime);
-
-            // Compare types now that bound regions have been replaced.
-            // Reorder the inputs so that the expected is passed first.
-            let result = if sub_is_expected {
-                self.sub().relate(sub_prime, sup_prime)
-            } else {
-                self.sup().relate(sup_prime, sub_prime)
-            };
-
-            if result.is_ok() {
-                debug!("OK result={result:?}");
-            }
-            // NOTE: returning the result here would be dangerous as it contains
-            // placeholders which **must not** be named afterwards.
-            result.map(|_| ())
-        })
-    }
-}
+use rustc_middle::ty::relate::RelateResult;
+use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable};
 
 impl<'tcx> InferCtxt<'tcx> {
     /// Replaces all bound variables (lifetimes, types, and constants) bound by
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
index c053adc5e07..18822351f4f 100644
--- a/compiler/rustc_infer/src/infer/relate/type_relating.rs
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -1,4 +1,5 @@
 use super::combine::CombineFields;
+use crate::infer::BoundRegionConversionTime::HigherRankedType;
 use crate::infer::{
     DefineOpaqueTypes, ObligationEmittingRelation, StructurallyRelateAliases, SubregionOrigin,
 };
@@ -214,16 +215,59 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         {
             self.relate(a, b)?;
         } else {
+            let span = self.fields.trace.cause.span;
+            let infcx = self.fields.infcx;
+
             match self.ambient_variance {
+                // Checks whether `for<..> sub <: for<..> sup` holds.
+                //
+                // For this to hold, **all** instantiations of the super type
+                // have to be a super type of **at least one** instantiation of
+                // the subtype.
+                //
+                // This is implemented by first entering a new universe.
+                // We then replace all bound variables in `sup` with placeholders,
+                // and all bound variables in `sub` with inference vars.
+                // We can then just relate the two resulting types as normal.
+                //
+                // Note: this is a subtle algorithm. For a full explanation, please see
+                // the [rustc dev guide][rd]
+                //
+                // [rd]: https://rustc-dev-guide.rust-lang.org/borrow_check/region_inference/placeholders_and_universes.html
                 ty::Covariant => {
-                    self.fields.higher_ranked_sub(a, b, true)?;
+                    infcx.enter_forall(b, |b| {
+                        let a = infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, a);
+                        self.relate(a, b)
+                    })?;
                 }
                 ty::Contravariant => {
-                    self.fields.higher_ranked_sub(b, a, false)?;
+                    infcx.enter_forall(a, |a| {
+                        let b = infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, b);
+                        self.relate(a, b)
+                    })?;
                 }
+
+                // When **equating** binders, we check that there is a 1-to-1
+                // correspondence between the bound vars in both types.
+                //
+                // We do so by separately instantiating one of the binders with
+                // placeholders and the other with inference variables and then
+                // equating the instantiated types.
+                //
+                // We want `for<..> A == for<..> B` -- therefore we want
+                // `exists<..> A == for<..> B` and `exists<..> B == for<..> A`.
+                // Check if `exists<..> A == for<..> B`
                 ty::Invariant => {
-                    self.fields.higher_ranked_sub(a, b, true)?;
-                    self.fields.higher_ranked_sub(b, a, false)?;
+                    infcx.enter_forall(b, |b| {
+                        let a = infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, a);
+                        self.relate(a, b)
+                    })?;
+
+                    // Check if `exists<..> B == for<..> A`.
+                    infcx.enter_forall(a, |a| {
+                        let b = infcx.instantiate_binder_with_fresh_vars(span, HigherRankedType, b);
+                        self.relate(a, b)
+                    })?;
                 }
                 ty::Bivariant => {
                     unreachable!("Expected bivariance to be handled in relate_with_variance")