From 5dc3fd7c05010654b4d80b87eac8c937e4808607 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 21 Mar 2023 21:50:16 +0000 Subject: [PATCH] Include relation direction in AliasEq predicate --- compiler/rustc_infer/src/infer/combine.rs | 8 +- compiler/rustc_infer/src/infer/equate.rs | 4 + compiler/rustc_infer/src/infer/glb.rs | 5 + compiler/rustc_infer/src/infer/lub.rs | 5 + .../rustc_infer/src/infer/nll_relate/mod.rs | 10 ++ compiler/rustc_infer/src/infer/sub.rs | 4 + compiler/rustc_middle/src/ty/flags.rs | 2 +- compiler/rustc_middle/src/ty/mod.rs | 24 ++++- compiler/rustc_middle/src/ty/print/pretty.rs | 3 +- .../rustc_middle/src/ty/structural_impls.rs | 4 +- compiler/rustc_privacy/src/lib.rs | 2 +- .../src/solve/eval_ctxt.rs | 4 +- .../src/solve/fulfill.rs | 2 +- .../rustc_trait_selection/src/solve/mod.rs | 91 ++++++++++++++----- .../traits/error_reporting/method_chain.rs | 5 + tests/ui/traits/new-solver/alias-sub.rs | 34 +++++++ 16 files changed, 173 insertions(+), 34 deletions(-) create mode 100644 tests/ui/traits/new-solver/alias-sub.rs diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs index 4503af03ca3..c82438f05e1 100644 --- a/compiler/rustc_infer/src/infer/combine.rs +++ b/compiler/rustc_infer/src/infer/combine.rs @@ -842,7 +842,7 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> { let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) }; self.register_predicates([ty::Binder::dummy(if self.tcx().trait_solver_next() { - ty::PredicateKind::AliasEq(a.into(), b.into()) + ty::PredicateKind::AliasEq(a.into(), b.into(), ty::AliasRelationDirection::Equate) } else { ty::PredicateKind::ConstEquate(a, b) })]); @@ -852,13 +852,15 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> { /// /// If they aren't equal then the relation doesn't hold. fn register_type_equate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) { - let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) }; - self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasEq( a.into(), b.into(), + self.alias_relate_direction(), ))]); } + + /// Relation direction emitted for `AliasEq` predicates + fn alias_relate_direction(&self) -> ty::AliasRelationDirection; } fn int_unification_error<'tcx>( diff --git a/compiler/rustc_infer/src/infer/equate.rs b/compiler/rustc_infer/src/infer/equate.rs index c92a74b6241..38002357cde 100644 --- a/compiler/rustc_infer/src/infer/equate.rs +++ b/compiler/rustc_infer/src/infer/equate.rs @@ -210,4 +210,8 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> { fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { self.fields.register_obligations(obligations); } + + fn alias_relate_direction(&self) -> ty::AliasRelationDirection { + ty::AliasRelationDirection::Equate + } } diff --git a/compiler/rustc_infer/src/infer/glb.rs b/compiler/rustc_infer/src/infer/glb.rs index 5c12351226a..6395c4d4b20 100644 --- a/compiler/rustc_infer/src/infer/glb.rs +++ b/compiler/rustc_infer/src/infer/glb.rs @@ -155,4 +155,9 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> { fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { self.fields.register_obligations(obligations); } + + fn alias_relate_direction(&self) -> ty::AliasRelationDirection { + // FIXME(deferred_projection_equality): This isn't right, I think? + ty::AliasRelationDirection::Equate + } } diff --git a/compiler/rustc_infer/src/infer/lub.rs b/compiler/rustc_infer/src/infer/lub.rs index dbef42db8f1..98cbd4c561c 100644 --- a/compiler/rustc_infer/src/infer/lub.rs +++ b/compiler/rustc_infer/src/infer/lub.rs @@ -155,4 +155,9 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> { fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { self.fields.register_obligations(obligations) } + + fn alias_relate_direction(&self) -> ty::AliasRelationDirection { + // FIXME(deferred_projection_equality): This isn't right, I think? + ty::AliasRelationDirection::Equate + } } diff --git a/compiler/rustc_infer/src/infer/nll_relate/mod.rs b/compiler/rustc_infer/src/infer/nll_relate/mod.rs index 573cd91a2a2..216d2965da7 100644 --- a/compiler/rustc_infer/src/infer/nll_relate/mod.rs +++ b/compiler/rustc_infer/src/infer/nll_relate/mod.rs @@ -777,6 +777,16 @@ where fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { self.delegate.register_obligations(obligations); } + + fn alias_relate_direction(&self) -> ty::AliasRelationDirection { + match self.ambient_variance { + ty::Variance::Covariant => ty::AliasRelationDirection::Subtype, + ty::Variance::Contravariant => ty::AliasRelationDirection::Supertype, + ty::Variance::Invariant => ty::AliasRelationDirection::Equate, + // FIXME(deferred_projection_equality): Implement this when we trigger it + ty::Variance::Bivariant => unreachable!(), + } + } } /// When we encounter a binder like `for<..> fn(..)`, we actually have diff --git a/compiler/rustc_infer/src/infer/sub.rs b/compiler/rustc_infer/src/infer/sub.rs index 230cadb1184..fc73ca7606d 100644 --- a/compiler/rustc_infer/src/infer/sub.rs +++ b/compiler/rustc_infer/src/infer/sub.rs @@ -236,4 +236,8 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> { fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { self.fields.register_obligations(obligations); } + + fn alias_relate_direction(&self) -> ty::AliasRelationDirection { + ty::AliasRelationDirection::Subtype + } } diff --git a/compiler/rustc_middle/src/ty/flags.rs b/compiler/rustc_middle/src/ty/flags.rs index 91241ff404f..1ae47fddccc 100644 --- a/compiler/rustc_middle/src/ty/flags.rs +++ b/compiler/rustc_middle/src/ty/flags.rs @@ -288,7 +288,7 @@ impl FlagComputation { self.add_ty(ty); } ty::PredicateKind::Ambiguous => {} - ty::PredicateKind::AliasEq(t1, t2) => { + ty::PredicateKind::AliasEq(t1, t2, _) => { self.add_term(t1); self.add_term(t2); } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index d383a413208..d1e268f49c4 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -640,7 +640,25 @@ pub enum PredicateKind<'tcx> { /// This predicate requires two terms to be equal to eachother. /// /// Only used for new solver - AliasEq(Term<'tcx>, Term<'tcx>), + AliasEq(Term<'tcx>, Term<'tcx>, AliasRelationDirection), +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)] +#[derive(HashStable, Debug)] +pub enum AliasRelationDirection { + Equate, + Subtype, + Supertype, +} + +impl AliasRelationDirection { + pub fn invert(self) -> Self { + match self { + AliasRelationDirection::Equate => AliasRelationDirection::Equate, + AliasRelationDirection::Subtype => AliasRelationDirection::Supertype, + AliasRelationDirection::Supertype => AliasRelationDirection::Subtype, + } + } } /// The crate outlives map is computed during typeck and contains the @@ -976,11 +994,11 @@ impl<'tcx> Term<'tcx> { } } - /// This function returns `None` for `AliasKind::Opaque`. + /// This function returns the inner `AliasTy` if this term is a projection. /// /// FIXME: rename `AliasTy` to `AliasTerm` and make sure we correctly /// deal with constants. - pub fn to_alias_term_no_opaque(&self, tcx: TyCtxt<'tcx>) -> Option> { + pub fn to_projection_term(&self, tcx: TyCtxt<'tcx>) -> Option> { match self.unpack() { TermKind::Ty(ty) => match ty.kind() { ty::Alias(kind, alias_ty) => match kind { diff --git a/compiler/rustc_middle/src/ty/print/pretty.rs b/compiler/rustc_middle/src/ty/print/pretty.rs index fffdbfc9660..e96856231ce 100644 --- a/compiler/rustc_middle/src/ty/print/pretty.rs +++ b/compiler/rustc_middle/src/ty/print/pretty.rs @@ -2847,7 +2847,8 @@ define_print_and_forward_display! { p!("the type `", print(ty), "` is found in the environment") } ty::PredicateKind::Ambiguous => p!("ambiguous"), - ty::PredicateKind::AliasEq(t1, t2) => p!(print(t1), " == ", print(t2)), + // TODO + ty::PredicateKind::AliasEq(t1, t2, _) => p!(print(t1), " == ", print(t2)), } } diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index ef643531bb2..92a2b4b18dd 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -177,7 +177,8 @@ impl<'tcx> fmt::Debug for ty::PredicateKind<'tcx> { write!(f, "TypeWellFormedFromEnv({:?})", ty) } ty::PredicateKind::Ambiguous => write!(f, "Ambiguous"), - ty::PredicateKind::AliasEq(t1, t2) => write!(f, "AliasEq({t1:?}, {t2:?})"), + // TODO + ty::PredicateKind::AliasEq(t1, t2, _) => write!(f, "AliasEq({t1:?}, {t2:?})"), } } } @@ -250,6 +251,7 @@ TrivialTypeTraversalAndLiftImpls! { crate::ty::AssocItem, crate::ty::AssocKind, crate::ty::AliasKind, + crate::ty::AliasRelationDirection, crate::ty::Placeholder, crate::ty::Placeholder, crate::ty::ClosureKind, diff --git a/compiler/rustc_privacy/src/lib.rs b/compiler/rustc_privacy/src/lib.rs index d884ebd9acc..fda9ef008d7 100644 --- a/compiler/rustc_privacy/src/lib.rs +++ b/compiler/rustc_privacy/src/lib.rs @@ -180,7 +180,7 @@ where | ty::PredicateKind::ConstEquate(_, _) | ty::PredicateKind::TypeWellFormedFromEnv(_) | ty::PredicateKind::Ambiguous - | ty::PredicateKind::AliasEq(_, _) => bug!("unexpected predicate: {:?}", predicate), + | ty::PredicateKind::AliasEq(..) => bug!("unexpected predicate: {:?}", predicate), } } diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt.rs index 95412922357..106a989b22f 100644 --- a/compiler/rustc_trait_selection/src/solve/eval_ctxt.rs +++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt.rs @@ -223,8 +223,8 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> { ty::PredicateKind::TypeWellFormedFromEnv(..) => { bug!("TypeWellFormedFromEnv is only used for Chalk") } - ty::PredicateKind::AliasEq(lhs, rhs) => { - self.compute_alias_eq_goal(Goal { param_env, predicate: (lhs, rhs) }) + ty::PredicateKind::AliasEq(lhs, rhs, direction) => { + self.compute_alias_eq_goal(Goal { param_env, predicate: (lhs, rhs, direction) }) } } } else { diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs index 38120b9760f..a9c9e5761a8 100644 --- a/compiler/rustc_trait_selection/src/solve/fulfill.rs +++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs @@ -73,7 +73,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> { MismatchedProjectionTypes { err: TypeError::Mismatch }, ) } - ty::PredicateKind::AliasEq(_, _) => { + ty::PredicateKind::AliasEq(_, _, _) => { FulfillmentErrorCode::CodeProjectionError( MismatchedProjectionTypes { err: TypeError::Mismatch }, ) diff --git a/compiler/rustc_trait_selection/src/solve/mod.rs b/compiler/rustc_trait_selection/src/solve/mod.rs index 606c2eaa510..48ef089a7b2 100644 --- a/compiler/rustc_trait_selection/src/solve/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/mod.rs @@ -158,13 +158,35 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> { #[instrument(level = "debug", skip(self), ret)] fn compute_alias_eq_goal( &mut self, - goal: Goal<'tcx, (ty::Term<'tcx>, ty::Term<'tcx>)>, + goal: Goal<'tcx, (ty::Term<'tcx>, ty::Term<'tcx>, ty::AliasRelationDirection)>, ) -> QueryResult<'tcx> { let tcx = self.tcx(); - let evaluate_normalizes_to = |ecx: &mut EvalCtxt<'_, 'tcx>, alias, other| { + let evaluate_normalizes_to = |ecx: &mut EvalCtxt<'_, 'tcx>, alias, other, direction| { debug!("evaluate_normalizes_to(alias={:?}, other={:?})", alias, other); - let r = ecx.probe(|ecx| { + let result = ecx.probe(|ecx| { + let other = match direction { + // This is purely an optimization. + ty::AliasRelationDirection::Equate => other, + + ty::AliasRelationDirection::Subtype | ty::AliasRelationDirection::Supertype => { + let fresh = ecx.next_term_infer_of_kind(other); + let (sub, sup) = if direction == ty::AliasRelationDirection::Subtype { + (fresh, other) + } else { + (other, fresh) + }; + ecx.add_goals( + ecx.infcx + .at(&ObligationCause::dummy(), goal.param_env) + .sub(DefineOpaqueTypes::No, sub, sup)? + .into_obligations() + .into_iter() + .map(|o| o.into()), + ); + fresh + } + }; ecx.add_goal(goal.with( tcx, ty::Binder::dummy(ty::ProjectionPredicate { @@ -174,37 +196,64 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> { )); ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) }); - debug!("evaluate_normalizes_to(..) -> {:?}", r); - r + debug!("evaluate_normalizes_to({alias}, {other}, {direction:?}) -> {result:?}"); + result }; - if goal.predicate.0.is_infer() || goal.predicate.1.is_infer() { + let (lhs, rhs, direction) = goal.predicate; + + if lhs.is_infer() || rhs.is_infer() { bug!( "`AliasEq` goal with an infer var on lhs or rhs which should have been instantiated" ); } - match ( - goal.predicate.0.to_alias_term_no_opaque(tcx), - goal.predicate.1.to_alias_term_no_opaque(tcx), - ) { + match (lhs.to_projection_term(tcx), rhs.to_projection_term(tcx)) { (None, None) => bug!("`AliasEq` goal without an alias on either lhs or rhs"), - (Some(alias), None) => evaluate_normalizes_to(self, alias, goal.predicate.1), - (None, Some(alias)) => evaluate_normalizes_to(self, alias, goal.predicate.0), + + // RHS is not a projection, only way this is true is if LHS normalizes-to RHS + (Some(alias_lhs), None) => evaluate_normalizes_to(self, alias_lhs, rhs, direction), + + // LHS is not a projection, only way this is true is if RHS normalizes-to LHS + (None, Some(alias_rhs)) => { + evaluate_normalizes_to(self, alias_rhs, lhs, direction.invert()) + } + (Some(alias_lhs), Some(alias_rhs)) => { debug!("compute_alias_eq_goal: both sides are aliases"); - let mut candidates = Vec::with_capacity(3); + let candidates = vec![ + // LHS normalizes-to RHS + evaluate_normalizes_to(self, alias_lhs, rhs, direction), + // RHS normalizes-to RHS + evaluate_normalizes_to(self, alias_rhs, lhs, direction.invert()), + // Relate via substs + self.probe(|ecx| { + debug!("compute_alias_eq_goal: alias defids are equal, equating substs"); - // Evaluate all 3 potential candidates for the alias' being equal - candidates.push(evaluate_normalizes_to(self, alias_lhs, goal.predicate.1)); - candidates.push(evaluate_normalizes_to(self, alias_rhs, goal.predicate.0)); - candidates.push(self.probe(|ecx| { - debug!("compute_alias_eq_goal: alias defids are equal, equating substs"); - ecx.eq(goal.param_env, alias_lhs, alias_rhs)?; - ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) - })); + ecx.add_goals( + match direction { + ty::AliasRelationDirection::Equate => ecx + .infcx + .at(&ObligationCause::dummy(), goal.param_env) + .eq(DefineOpaqueTypes::No, alias_lhs, alias_rhs), + ty::AliasRelationDirection::Subtype => ecx + .infcx + .at(&ObligationCause::dummy(), goal.param_env) + .sub(DefineOpaqueTypes::No, alias_lhs, alias_rhs), + ty::AliasRelationDirection::Supertype => ecx + .infcx + .at(&ObligationCause::dummy(), goal.param_env) + .sup(DefineOpaqueTypes::No, alias_lhs, alias_rhs), + }? + .into_obligations() + .into_iter() + .map(|o| o.into()), + ); + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + }), + ]; debug!(?candidates); self.try_merge_responses(candidates.into_iter()) diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs index 1174efdbfa8..13607b9079a 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs @@ -92,6 +92,11 @@ impl<'a, 'tcx> TypeRelation<'tcx> for CollectAllMismatches<'a, 'tcx> { } impl<'tcx> ObligationEmittingRelation<'tcx> for CollectAllMismatches<'_, 'tcx> { + fn alias_relate_direction(&self) -> ty::AliasRelationDirection { + // FIXME(deferred_projection_equality): We really should get rid of this relation. + ty::AliasRelationDirection::Equate + } + fn register_obligations(&mut self, _obligations: PredicateObligations<'tcx>) { // FIXME(deferred_projection_equality) } diff --git a/tests/ui/traits/new-solver/alias-sub.rs b/tests/ui/traits/new-solver/alias-sub.rs new file mode 100644 index 00000000000..30c1981a92e --- /dev/null +++ b/tests/ui/traits/new-solver/alias-sub.rs @@ -0,0 +1,34 @@ +// compile-flags: -Ztrait-solver=next +// check-pass + +trait Trait { + type Assoc: Sized; +} + +impl Trait for &'static str { + type Assoc = &'static str; +} + +// Wrapper is just here to get around stupid `Sized` obligations in mir typeck +struct Wrapper(std::marker::PhantomData); +fn mk(x: T) -> Wrapper<::Assoc> { todo!() } + + +trait IsStaticStr {} +impl IsStaticStr for (&'static str,) {} +fn define(_: T) {} + +fn foo<'a, T: Trait>() { + let y = Default::default(); + + // `::Assoc <: &'a str` + // In the old solver, this would *equate* the LHS and RHS. + let _: Wrapper<&'a str> = mk(y); + + // ... then later on, we constrain `?0 = &'static str` + // but that should not mean that `'a = 'static`, because + // we should use *sub* above. + define((y,)); +} + +fn main() {}