From 44a6f72a725d1e274d734473c95f95caa5c6fbb6 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Fri, 7 Jun 2024 10:05:47 -0400 Subject: [PATCH] Make ObligationEmittingRelation deal with Goals only --- .../src/type_check/relate_tys.rs | 38 ++++++++++---- compiler/rustc_infer/src/infer/at.rs | 52 ++++++++++--------- .../rustc_infer/src/infer/relate/combine.rs | 35 ++++++++++--- compiler/rustc_infer/src/infer/relate/glb.rs | 10 ++-- .../rustc_infer/src/infer/relate/lattice.rs | 9 +++- compiler/rustc_infer/src/infer/relate/lub.rs | 10 ++-- .../src/infer/relate/type_relating.rs | 25 +++++---- 7 files changed, 119 insertions(+), 60 deletions(-) diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs index cd51d73ba55..a87b9f7a23d 100644 --- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs +++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs @@ -3,7 +3,8 @@ use rustc_errors::ErrorGuaranteed; use rustc_infer::infer::relate::{ObligationEmittingRelation, StructurallyRelateAliases}; use rustc_infer::infer::relate::{Relate, RelateResult, TypeRelation}; use rustc_infer::infer::NllRegionVariableOrigin; -use rustc_infer::traits::{Obligation, PredicateObligation}; +use rustc_infer::traits::solve::Goal; +use rustc_infer::traits::Obligation; use rustc_middle::mir::ConstraintCategory; use rustc_middle::span_bug; use rustc_middle::traits::query::NoSolution; @@ -154,8 +155,13 @@ impl<'me, 'bccx, 'tcx> NllTypeRelating<'me, 'bccx, 'tcx> { ), }; let cause = ObligationCause::dummy_with_span(self.span()); - let obligations = infcx.handle_opaque_type(a, b, &cause, self.param_env())?.obligations; - self.register_obligations(obligations); + self.register_obligations( + infcx + .handle_opaque_type(a, b, &cause, self.param_env())? + .obligations + .into_iter() + .map(Goal::from), + ); Ok(()) } @@ -550,22 +556,32 @@ impl<'bccx, 'tcx> ObligationEmittingRelation<'tcx> for NllTypeRelating<'_, 'bccx &mut self, obligations: impl IntoIterator, ty::Predicate<'tcx>>>, ) { + let tcx = self.tcx(); + let param_env = self.param_env(); self.register_obligations( - obligations - .into_iter() - .map(|to_pred| { - Obligation::new(self.tcx(), ObligationCause::dummy(), self.param_env(), to_pred) - }) - .collect(), + obligations.into_iter().map(|to_pred| Goal::new(tcx, param_env, to_pred)), ); } - fn register_obligations(&mut self, obligations: Vec>) { + fn register_obligations( + &mut self, + obligations: impl IntoIterator>>, + ) { let _: Result<_, ErrorGuaranteed> = self.type_checker.fully_perform_op( self.locations, self.category, InstantiateOpaqueType { - obligations, + obligations: obligations + .into_iter() + .map(|goal| { + Obligation::new( + self.tcx(), + ObligationCause::dummy_with_span(self.span()), + goal.param_env, + goal.predicate, + ) + }) + .collect(), // These fields are filled in during execution of the operation base_universe: None, region_constraints: None, diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index 046d908d148..8994739f5c7 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -31,6 +31,8 @@ use crate::infer::relate::{Relate, StructurallyRelateAliases, TypeRelation}; use rustc_middle::bug; use rustc_middle::ty::{Const, ImplSubject}; +use crate::traits::Obligation; + /// Whether we should define opaque types or just treat them opaquely. /// /// Currently only used to prevent predicate matching from matching anything @@ -119,10 +121,8 @@ impl<'a, 'tcx> At<'a, 'tcx> { self.param_env, define_opaque_types, ); - fields - .sup() - .relate(expected, actual) - .map(|_| InferOk { value: (), obligations: fields.obligations }) + fields.sup().relate(expected, actual)?; + Ok(InferOk { value: (), obligations: fields.into_obligations() }) } /// Makes `expected <: actual`. @@ -141,10 +141,8 @@ impl<'a, 'tcx> At<'a, 'tcx> { self.param_env, define_opaque_types, ); - fields - .sub() - .relate(expected, actual) - .map(|_| InferOk { value: (), obligations: fields.obligations }) + fields.sub().relate(expected, actual)?; + Ok(InferOk { value: (), obligations: fields.into_obligations() }) } /// Makes `expected == actual`. @@ -163,10 +161,22 @@ impl<'a, 'tcx> At<'a, 'tcx> { self.param_env, define_opaque_types, ); - fields - .equate(StructurallyRelateAliases::No) - .relate(expected, actual) - .map(|_| InferOk { value: (), obligations: fields.obligations }) + fields.equate(StructurallyRelateAliases::No).relate(expected, actual)?; + Ok(InferOk { + value: (), + obligations: fields + .obligations + .into_iter() + .map(|goal| { + Obligation::new( + self.infcx.tcx, + fields.trace.cause.clone(), + goal.param_env, + goal.predicate, + ) + }) + .collect(), + }) } /// Equates `expected` and `found` while structurally relating aliases. @@ -187,10 +197,8 @@ impl<'a, 'tcx> At<'a, 'tcx> { self.param_env, DefineOpaqueTypes::Yes, ); - fields - .equate(StructurallyRelateAliases::Yes) - .relate(expected, actual) - .map(|_| InferOk { value: (), obligations: fields.obligations }) + fields.equate(StructurallyRelateAliases::Yes).relate(expected, actual)?; + Ok(InferOk { value: (), obligations: fields.into_obligations() }) } pub fn relate( @@ -237,10 +245,8 @@ impl<'a, 'tcx> At<'a, 'tcx> { self.param_env, define_opaque_types, ); - fields - .lub() - .relate(expected, actual) - .map(|value| InferOk { value, obligations: fields.obligations }) + let value = fields.lub().relate(expected, actual)?; + Ok(InferOk { value, obligations: fields.into_obligations() }) } /// Computes the greatest-lower-bound, or mutual subtype, of two @@ -261,10 +267,8 @@ impl<'a, 'tcx> At<'a, 'tcx> { self.param_env, define_opaque_types, ); - fields - .glb() - .relate(expected, actual) - .map(|value| InferOk { value, obligations: fields.obligations }) + let value = fields.glb().relate(expected, actual)?; + Ok(InferOk { value, obligations: fields.into_obligations() }) } } diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs index e62ef5d4ea4..1a0a0d10c6d 100644 --- a/compiler/rustc_infer/src/infer/relate/combine.rs +++ b/compiler/rustc_infer/src/infer/relate/combine.rs @@ -28,6 +28,7 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace}; use crate::traits::{Obligation, PredicateObligation}; use rustc_middle::bug; use rustc_middle::infer::unify_key::EffectVarValue; +use rustc_middle::traits::solve::Goal; use rustc_middle::ty::error::{ExpectedFound, TypeError}; use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitableExt, Upcast}; use rustc_middle::ty::{IntType, UintType}; @@ -38,7 +39,7 @@ pub struct CombineFields<'infcx, 'tcx> { pub infcx: &'infcx InferCtxt<'tcx>, pub trace: TypeTrace<'tcx>, pub param_env: ty::ParamEnv<'tcx>, - pub obligations: Vec>, + pub obligations: Vec>>, pub define_opaque_types: DefineOpaqueTypes, } @@ -51,6 +52,20 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { ) -> Self { Self { infcx, trace, param_env, define_opaque_types, obligations: vec![] } } + + pub(crate) fn into_obligations(self) -> Vec> { + self.obligations + .into_iter() + .map(|goal| { + Obligation::new( + self.infcx.tcx, + self.trace.cause.clone(), + goal.param_env, + goal.predicate, + ) + }) + .collect() + } } impl<'tcx> InferCtxt<'tcx> { @@ -290,7 +305,10 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { Glb::new(self) } - pub fn register_obligations(&mut self, obligations: Vec>) { + pub fn register_obligations( + &mut self, + obligations: impl IntoIterator>>, + ) { self.obligations.extend(obligations); } @@ -298,9 +316,11 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { &mut self, obligations: impl IntoIterator, ty::Predicate<'tcx>>>, ) { - self.obligations.extend(obligations.into_iter().map(|to_pred| { - Obligation::new(self.infcx.tcx, self.trace.cause.clone(), self.param_env, to_pred) - })) + self.obligations.extend( + obligations + .into_iter() + .map(|to_pred| Goal::new(self.infcx.tcx, self.param_env, to_pred)), + ) } } @@ -315,7 +335,10 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation> { fn structurally_relate_aliases(&self) -> StructurallyRelateAliases; /// Register obligations that must hold in order for this relation to hold - fn register_obligations(&mut self, obligations: Vec>); + fn register_obligations( + &mut self, + obligations: impl IntoIterator>>, + ); /// Register predicates that must hold in order for this relation to hold. Uses /// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs index ca772b349d2..6f37995ac1e 100644 --- a/compiler/rustc_infer/src/infer/relate/glb.rs +++ b/compiler/rustc_infer/src/infer/relate/glb.rs @@ -1,6 +1,7 @@ //! Greatest lower bound. See [`lattice`]. -use super::{Relate, RelateResult, TypeRelation}; +use rustc_middle::traits::solve::Goal; +use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation}; use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt}; use rustc_span::Span; @@ -8,7 +9,7 @@ use super::combine::{CombineFields, ObligationEmittingRelation}; use super::lattice::{self, LatticeDir}; use super::StructurallyRelateAliases; use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin}; -use crate::traits::{ObligationCause, PredicateObligation}; +use crate::traits::ObligationCause; /// "Greatest lower bound" (common subtype) pub struct Glb<'combine, 'infcx, 'tcx> { @@ -147,7 +148,10 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> { self.fields.register_predicates(obligations); } - fn register_obligations(&mut self, obligations: Vec>) { + fn register_obligations( + &mut self, + obligations: impl IntoIterator>>, + ) { self.fields.register_obligations(obligations); } diff --git a/compiler/rustc_infer/src/infer/relate/lattice.rs b/compiler/rustc_infer/src/infer/relate/lattice.rs index f05b984142a..8c6f1690ade 100644 --- a/compiler/rustc_infer/src/infer/relate/lattice.rs +++ b/compiler/rustc_infer/src/infer/relate/lattice.rs @@ -21,7 +21,8 @@ use super::combine::ObligationEmittingRelation; use crate::infer::{DefineOpaqueTypes, InferCtxt}; use crate::traits::ObligationCause; -use super::RelateResult; +use rustc_middle::traits::solve::Goal; +use rustc_middle::ty::relate::RelateResult; use rustc_middle::ty::TyVar; use rustc_middle::ty::{self, Ty}; @@ -109,7 +110,11 @@ where && !this.infcx().next_trait_solver() => { this.register_obligations( - infcx.handle_opaque_type(a, b, this.cause(), this.param_env())?.obligations, + infcx + .handle_opaque_type(a, b, this.cause(), this.param_env())? + .obligations + .into_iter() + .map(Goal::from), ); Ok(a) } diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs index 0b9de8de001..625cc02115a 100644 --- a/compiler/rustc_infer/src/infer/relate/lub.rs +++ b/compiler/rustc_infer/src/infer/relate/lub.rs @@ -4,9 +4,10 @@ use super::combine::{CombineFields, ObligationEmittingRelation}; use super::lattice::{self, LatticeDir}; use super::StructurallyRelateAliases; use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin}; -use crate::traits::{ObligationCause, PredicateObligation}; +use crate::traits::ObligationCause; -use super::{Relate, RelateResult, TypeRelation}; +use rustc_middle::traits::solve::Goal; +use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation}; use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt}; use rustc_span::Span; @@ -147,7 +148,10 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> { self.fields.register_predicates(obligations); } - fn register_obligations(&mut self, obligations: Vec>) { + fn register_obligations( + &mut self, + obligations: impl IntoIterator>>, + ) { self.fields.register_obligations(obligations) } diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs index 447e4d6bfd8..328e4d8902f 100644 --- a/compiler/rustc_infer/src/infer/relate/type_relating.rs +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -1,11 +1,10 @@ use super::combine::CombineFields; +use crate::infer::relate::{ObligationEmittingRelation, StructurallyRelateAliases}; use crate::infer::BoundRegionConversionTime::HigherRankedType; use crate::infer::{DefineOpaqueTypes, SubregionOrigin}; -use crate::traits::{Obligation, PredicateObligation}; - -use super::{ - relate_args_invariantly, relate_args_with_variances, ObligationEmittingRelation, Relate, - RelateResult, StructurallyRelateAliases, TypeRelation, +use rustc_middle::traits::solve::Goal; +use rustc_middle::ty::relate::{ + relate_args_invariantly, relate_args_with_variances, Relate, RelateResult, TypeRelation, }; use rustc_middle::ty::TyVar; use rustc_middle::ty::{self, Ty, TyCtxt}; @@ -88,9 +87,8 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, '_, 'tcx> { ty::Covariant => { // can't make progress on `A <: B` if both A and B are // type variables, so record an obligation. - self.fields.obligations.push(Obligation::new( + self.fields.obligations.push(Goal::new( self.tcx(), - self.fields.trace.cause.clone(), self.fields.param_env, ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: true, @@ -102,9 +100,8 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, '_, 'tcx> { ty::Contravariant => { // can't make progress on `B <: A` if both A and B are // type variables, so record an obligation. - self.fields.obligations.push(Obligation::new( + self.fields.obligations.push(Goal::new( self.tcx(), - self.fields.trace.cause.clone(), self.fields.param_env, ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: false, @@ -153,10 +150,13 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, '_, 'tcx> { && def_id.is_local() && !infcx.next_trait_solver() => { + // FIXME: Don't shuttle between Goal and Obligation self.fields.obligations.extend( infcx .handle_opaque_type(a, b, &self.fields.trace.cause, self.param_env())? - .obligations, + .obligations + .into_iter() + .map(Goal::from), ); } @@ -318,7 +318,10 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { self.fields.register_predicates(obligations); } - fn register_obligations(&mut self, obligations: Vec>) { + fn register_obligations( + &mut self, + obligations: impl IntoIterator>>, + ) { self.fields.register_obligations(obligations); }