diff --git a/compiler/rustc_infer/src/infer/context.rs b/compiler/rustc_infer/src/infer/context.rs index 10b22c98a0d..497be6b5404 100644 --- a/compiler/rustc_infer/src/infer/context.rs +++ b/compiler/rustc_infer/src/infer/context.rs @@ -136,6 +136,10 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { self.enter_forall(value, f) } + fn equate_ty_vids_raw(&self, a: rustc_type_ir::TyVid, b: rustc_type_ir::TyVid) { + self.inner.borrow_mut().type_variables().equate(a, b); + } + fn equate_int_vids_raw(&self, a: rustc_type_ir::IntVid, b: rustc_type_ir::IntVid) { self.inner.borrow_mut().int_unification_table().union(a, b); } @@ -152,6 +156,23 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { self.inner.borrow_mut().effect_unification_table().union(a, b); } + fn instantiate_ty_var_raw>( + &self, + relation: &mut R, + target_is_expected: bool, + target_vid: rustc_type_ir::TyVid, + instantiation_variance: rustc_type_ir::Variance, + source_ty: Ty<'tcx>, + ) -> RelateResult<'tcx, ()> { + self.instantiate_ty_var( + relation, + target_is_expected, + target_vid, + instantiation_variance, + source_ty, + ) + } + fn instantiate_int_var_raw( &self, vid: rustc_type_ir::IntVid, @@ -228,7 +249,19 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { } fn sub_regions(&self, sub: ty::Region<'tcx>, sup: ty::Region<'tcx>) { - self.sub_regions(SubregionOrigin::RelateRegionParamBound(DUMMY_SP, None), sub, sup) + self.inner.borrow_mut().unwrap_region_constraints().make_subregion( + SubregionOrigin::RelateRegionParamBound(DUMMY_SP, None), + sub, + sup, + ); + } + + fn equate_regions(&self, a: ty::Region<'tcx>, b: ty::Region<'tcx>) { + self.inner.borrow_mut().unwrap_region_constraints().make_eqregion( + SubregionOrigin::RelateRegionParamBound(DUMMY_SP, None), + a, + b, + ); } fn register_ty_outlives(&self, ty: Ty<'tcx>, r: ty::Region<'tcx>) { diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 4a02fce5e7d..90265f67bc1 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -698,6 +698,12 @@ impl<'tcx> rustc_type_ir::inherent::Features> for &'tcx rustc_featu } } +impl<'tcx> rustc_type_ir::inherent::Span> for Span { + fn dummy() -> Self { + DUMMY_SP + } +} + type InternedSet<'tcx, T> = ShardedHashMap, ()>; pub struct CtxtInterners<'tcx> { diff --git a/compiler/rustc_type_ir/src/infer_ctxt.rs b/compiler/rustc_type_ir/src/infer_ctxt.rs index 2f6476c6e23..4b44b5a495e 100644 --- a/compiler/rustc_type_ir/src/infer_ctxt.rs +++ b/compiler/rustc_type_ir/src/infer_ctxt.rs @@ -67,11 +67,20 @@ pub trait InferCtxtLike: Sized { f: impl FnOnce(T) -> U, ) -> U; + fn equate_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid); fn equate_int_vids_raw(&self, a: ty::IntVid, b: ty::IntVid); fn equate_float_vids_raw(&self, a: ty::FloatVid, b: ty::FloatVid); fn equate_const_vids_raw(&self, a: ty::ConstVid, b: ty::ConstVid); fn equate_effect_vids_raw(&self, a: ty::EffectVid, b: ty::EffectVid); + fn instantiate_ty_var_raw>( + &self, + relation: &mut R, + target_is_expected: bool, + target_vid: ty::TyVid, + instantiation_variance: ty::Variance, + source_ty: ::Ty, + ) -> RelateResult; fn instantiate_int_var_raw(&self, vid: ty::IntVid, value: ty::IntVarValue); fn instantiate_float_var_raw(&self, vid: ty::FloatVid, value: ty::FloatVarValue); fn instantiate_effect_var_raw( @@ -125,6 +134,12 @@ pub trait InferCtxtLike: Sized { sup: ::Region, ); + fn equate_regions( + &self, + a: ::Region, + b: ::Region, + ); + fn register_ty_outlives( &self, ty: ::Ty, diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index 69665df4bfc..f7875bb5152 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -565,6 +565,10 @@ pub trait BoundExistentialPredicates: ) -> impl IntoIterator>>; } +pub trait Span: Copy + Debug + Hash + Eq + TypeFoldable { + fn dummy() -> Self; +} + pub trait SliceLike: Sized + Copy { type Item: Copy; type IntoIter: Iterator; diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index b2ac67efef6..a72e7b482a6 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -36,7 +36,7 @@ pub trait Interner: { type DefId: DefId; type LocalDefId: Copy + Debug + Hash + Eq + Into + TypeFoldable; - type Span: Copy + Debug + Hash + Eq + TypeFoldable; + type Span: Span; type GenericArgs: GenericArgs; type GenericArgsSlice: Copy + Debug + Hash + Eq + SliceLike; diff --git a/compiler/rustc_type_ir/src/lib.rs b/compiler/rustc_type_ir/src/lib.rs index 02a9ad1e35f..51c887fc4da 100644 --- a/compiler/rustc_type_ir/src/lib.rs +++ b/compiler/rustc_type_ir/src/lib.rs @@ -206,8 +206,8 @@ pub fn debug_bound_var( } } -#[derive(Copy, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "nightly", derive(Decodable, Encodable, Hash, HashStable_NoContext))] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "nightly", derive(Decodable, Encodable, HashStable_NoContext))] #[cfg_attr(feature = "nightly", rustc_pass_by_value)] pub enum Variance { Covariant, // T <: T iff A <: B -- e.g., function return type diff --git a/compiler/rustc_type_ir/src/relate.rs b/compiler/rustc_type_ir/src/relate.rs index 8e5d917e915..1302906adab 100644 --- a/compiler/rustc_type_ir/src/relate.rs +++ b/compiler/rustc_type_ir/src/relate.rs @@ -10,6 +10,7 @@ use crate::inherent::*; use crate::{self as ty, Interner}; pub mod combine; +pub mod solver_relating; pub type RelateResult = Result>; diff --git a/compiler/rustc_type_ir/src/relate/solver_relating.rs b/compiler/rustc_type_ir/src/relate/solver_relating.rs new file mode 100644 index 00000000000..18a4d5189bb --- /dev/null +++ b/compiler/rustc_type_ir/src/relate/solver_relating.rs @@ -0,0 +1,324 @@ +pub use rustc_type_ir::relate::*; +use rustc_type_ir::solve::Goal; +use rustc_type_ir::{self as ty, InferCtxtLike, Interner}; +use tracing::{debug, instrument}; + +use self::combine::{InferCtxtCombineExt, PredicateEmittingRelation}; +use crate::data_structures::DelayedSet; + +#[allow(unused)] +/// Enforce that `a` is equal to or a subtype of `b`. +pub struct SolverRelating<'infcx, Infcx, I: Interner> { + infcx: &'infcx Infcx, + // Immutable fields. + structurally_relate_aliases: StructurallyRelateAliases, + param_env: I::ParamEnv, + // Mutable fields. + ambient_variance: ty::Variance, + goals: Vec>, + /// The cache only tracks the `ambient_variance` as it's the + /// only field which is mutable and which meaningfully changes + /// the result when relating types. + /// + /// The cache does not track whether the state of the + /// `Infcx` has been changed or whether we've added any + /// goals to `self.goals`. Whether a goal is added once or multiple + /// times is not really meaningful. + /// + /// Changes in the inference state may delay some type inference to + /// the next fulfillment loop. Given that this loop is already + /// necessary, this is also not a meaningful change. Consider + /// the following three relations: + /// ```text + /// Vec sub Vec + /// ?0 eq u32 + /// Vec sub Vec + /// ``` + /// Without a cache, the second `Vec sub Vec` would eagerly + /// constrain `?1` to `u32`. When using the cache entry from the + /// first time we've related these types, this only happens when + /// later proving the `Subtype(?0, ?1)` goal from the first relation. + cache: DelayedSet<(ty::Variance, I::Ty, I::Ty)>, +} + +impl<'infcx, Infcx, I> SolverRelating<'infcx, Infcx, I> +where + Infcx: InferCtxtLike, + I: Interner, +{ +} + +impl TypeRelation for SolverRelating<'_, Infcx, I> +where + Infcx: InferCtxtLike, + I: Interner, +{ + fn cx(&self) -> I { + self.infcx.cx() + } + + fn relate_item_args( + &mut self, + item_def_id: I::DefId, + a_arg: I::GenericArgs, + b_arg: I::GenericArgs, + ) -> RelateResult { + if self.ambient_variance == ty::Invariant { + // Avoid fetching the variance if we are in an invariant + // context; no need, and it can induce dependency cycles + // (e.g., #41849). + relate_args_invariantly(self, a_arg, b_arg) + } else { + let tcx = self.cx(); + let opt_variances = tcx.variances_of(item_def_id); + relate_args_with_variances(self, item_def_id, opt_variances, a_arg, b_arg, false) + } + } + + fn relate_with_variance>( + &mut self, + variance: ty::Variance, + _info: VarianceDiagInfo, + a: T, + b: T, + ) -> RelateResult { + let old_ambient_variance = self.ambient_variance; + self.ambient_variance = self.ambient_variance.xform(variance); + debug!(?self.ambient_variance, "new ambient variance"); + + let r = if self.ambient_variance == ty::Bivariant { Ok(a) } else { self.relate(a, b) }; + + self.ambient_variance = old_ambient_variance; + r + } + + #[instrument(skip(self), level = "trace")] + fn tys(&mut self, a: I::Ty, b: I::Ty) -> RelateResult { + if a == b { + return Ok(a); + } + + let infcx = self.infcx; + let a = infcx.shallow_resolve(a); + let b = infcx.shallow_resolve(b); + + if self.cache.contains(&(self.ambient_variance, a, b)) { + return Ok(a); + } + + match (a.kind(), b.kind()) { + (ty::Infer(ty::TyVar(a_id)), ty::Infer(ty::TyVar(b_id))) => { + match self.ambient_variance { + ty::Covariant => { + // can't make progress on `A <: B` if both A and B are + // type variables, so record an obligation. + self.goals.push(Goal::new( + self.cx(), + self.param_env, + ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { + a_is_expected: true, + a, + b, + })), + )); + } + ty::Contravariant => { + // can't make progress on `B <: A` if both A and B are + // type variables, so record an obligation. + self.goals.push(Goal::new( + self.cx(), + self.param_env, + ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { + a_is_expected: false, + a: b, + b: a, + })), + )); + } + ty::Invariant => { + infcx.equate_ty_vids_raw(a_id, b_id); + } + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + } + } + + (ty::Infer(ty::TyVar(a_vid)), _) => { + infcx.instantiate_ty_var_raw(self, true, a_vid, self.ambient_variance, b)?; + } + (_, ty::Infer(ty::TyVar(b_vid))) => { + infcx.instantiate_ty_var_raw( + self, + false, + b_vid, + self.ambient_variance.xform(ty::Contravariant), + a, + )?; + } + + _ => { + self.infcx.super_combine_tys(self, a, b)?; + } + } + + assert!(self.cache.insert((self.ambient_variance, a, b))); + + Ok(a) + } + + #[instrument(skip(self), level = "trace")] + fn regions(&mut self, a: I::Region, b: I::Region) -> RelateResult { + match self.ambient_variance { + // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a) + ty::Covariant => self.infcx.sub_regions(b, a), + // Suptype(&'a u8, &'b u8) => Outlives('b: 'a) => SubRegion('a, 'b) + ty::Contravariant => self.infcx.sub_regions(a, b), + ty::Invariant => self.infcx.equate_regions(a, b), + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + } + + Ok(a) + } + + #[instrument(skip(self), level = "trace")] + fn consts(&mut self, a: I::Const, b: I::Const) -> RelateResult { + self.infcx.super_combine_consts(self, a, b) + } + + fn binders( + &mut self, + a: ty::Binder, + b: ty::Binder, + ) -> RelateResult> + where + T: Relate, + { + // If they're equal, then short-circuit. + if a == b { + return Ok(a); + } + + // If they have no bound vars, relate normally. + if let Some(a_inner) = a.no_bound_vars() { + if let Some(b_inner) = b.no_bound_vars() { + self.relate(a_inner, b_inner)?; + return Ok(a); + } + }; + + match self.ambient_variance { + // Checks whether `for<..> sub <: for<..> sup` holds. + // + // For this to hold, **all** instantiations of the super type + // have to be a super type of **at least one** instantiation of + // the subtype. + // + // This is implemented by first entering a new universe. + // We then replace all bound variables in `sup` with placeholders, + // and all bound variables in `sub` with inference vars. + // We can then just relate the two resulting types as normal. + // + // Note: this is a subtle algorithm. For a full explanation, please see + // the [rustc dev guide][rd] + // + // [rd]: https://rustc-dev-guide.rust-lang.org/borrow_check/region_inference/placeholders_and_universes.html + ty::Covariant => { + self.infcx.enter_forall(b, |b| { + let a = self.infcx.instantiate_binder_with_infer(a); + self.relate(a, b) + })?; + } + ty::Contravariant => { + self.infcx.enter_forall(a, |a| { + let b = self.infcx.instantiate_binder_with_infer(b); + self.relate(a, b) + })?; + } + + // When **equating** binders, we check that there is a 1-to-1 + // correspondence between the bound vars in both types. + // + // We do so by separately instantiating one of the binders with + // placeholders and the other with inference variables and then + // equating the instantiated types. + // + // We want `for<..> A == for<..> B` -- therefore we want + // `exists<..> A == for<..> B` and `exists<..> B == for<..> A`. + // Check if `exists<..> A == for<..> B` + ty::Invariant => { + self.infcx.enter_forall(b, |b| { + let a = self.infcx.instantiate_binder_with_infer(a); + self.relate(a, b) + })?; + + // Check if `exists<..> B == for<..> A`. + self.infcx.enter_forall(a, |a| { + let b = self.infcx.instantiate_binder_with_infer(b); + self.relate(a, b) + })?; + } + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + } + Ok(a) + } +} + +impl PredicateEmittingRelation for SolverRelating<'_, Infcx, I> +where + Infcx: InferCtxtLike, + I: Interner, +{ + fn span(&self) -> I::Span { + Span::dummy() + } + + fn param_env(&self) -> I::ParamEnv { + self.param_env + } + + fn structurally_relate_aliases(&self) -> StructurallyRelateAliases { + self.structurally_relate_aliases + } + + fn register_predicates( + &mut self, + obligations: impl IntoIterator>, + ) { + self.goals.extend( + obligations.into_iter().map(|pred| Goal::new(self.infcx.cx(), self.param_env, pred)), + ); + } + + fn register_goals(&mut self, obligations: impl IntoIterator>) { + self.goals.extend(obligations); + } + + fn register_alias_relate_predicate(&mut self, a: I::Ty, b: I::Ty) { + self.register_predicates([ty::Binder::dummy(match self.ambient_variance { + ty::Covariant => ty::PredicateKind::AliasRelate( + a.into(), + b.into(), + ty::AliasRelationDirection::Subtype, + ), + // a :> b is b <: a + ty::Contravariant => ty::PredicateKind::AliasRelate( + b.into(), + a.into(), + ty::AliasRelationDirection::Subtype, + ), + ty::Invariant => ty::PredicateKind::AliasRelate( + a.into(), + b.into(), + ty::AliasRelationDirection::Equate, + ), + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + })]); + } +}