From 357665dae90b28b670ee343f012620183cfc9c2b Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 21 Nov 2024 19:06:52 +0000 Subject: [PATCH] Simplify fulfill_implication --- .../src/impl_wf_check/min_specialization.rs | 9 +- .../src/traits/specialize/mod.rs | 191 +++++++++--------- .../rustc_trait_selection/src/traits/util.rs | 35 +--- ...pecialize_with_generalize_lifetimes.stderr | 12 +- 4 files changed, 106 insertions(+), 141 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/impl_wf_check/min_specialization.rs b/compiler/rustc_hir_analysis/src/impl_wf_check/min_specialization.rs index 34effd199f1..246643d8074 100644 --- a/compiler/rustc_hir_analysis/src/impl_wf_check/min_specialization.rs +++ b/compiler/rustc_hir_analysis/src/impl_wf_check/min_specialization.rs @@ -70,6 +70,7 @@ use rustc_hir as hir; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_infer::infer::TyCtxtInferExt; use rustc_infer::infer::outlives::env::OutlivesEnvironment; +use rustc_infer::traits::ObligationCause; use rustc_infer::traits::specialization_graph::Node; use rustc_middle::ty::trait_def::TraitSpecializationKind; use rustc_middle::ty::{ @@ -210,13 +211,7 @@ fn get_impl_args( impl1_def_id.to_def_id(), impl1_args, impl2_node, - |_, span| { - traits::ObligationCause::new( - impl1_span, - impl1_def_id, - traits::ObligationCauseCode::WhereClause(impl2_node.def_id(), span), - ) - }, + &ObligationCause::misc(impl1_span, impl1_def_id), ); let errors = ocx.select_all_or_error(); diff --git a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs index 5bf3dbcbc32..a9cd705465e 100644 --- a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs @@ -15,24 +15,24 @@ use rustc_data_structures::fx::FxIndexSet; use rustc_errors::codes::*; use rustc_errors::{Diag, EmissionGuarantee}; use rustc_hir::def_id::{DefId, LocalDefId}; -use rustc_infer::infer::DefineOpaqueTypes; use rustc_middle::bug; use rustc_middle::query::LocalCrate; use rustc_middle::ty::print::PrintTraitRefExt as _; -use rustc_middle::ty::{ - self, GenericArgsRef, ImplSubject, Ty, TyCtxt, TypeVisitableExt, TypingMode, -}; +use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingMode}; use rustc_session::lint::builtin::{COHERENCE_LEAK_CHECK, ORDER_DEPENDENT_TRAIT_OBJECTS}; use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span, sym}; +use rustc_type_ir::solve::NoSolution; use specialization_graph::GraphExt; use tracing::{debug, instrument}; -use super::{SelectionContext, util}; use crate::error_reporting::traits::to_pretty_impl_header; use crate::errors::NegativePositiveConflict; -use crate::infer::{InferCtxt, InferOk, TyCtxtInferExt}; +use crate::infer::{InferCtxt, TyCtxtInferExt}; use crate::traits::select::IntercrateAmbiguityCause; -use crate::traits::{FutureCompatOverlapErrorKind, ObligationCause, ObligationCtxt, coherence}; +use crate::traits::{ + FutureCompatOverlapErrorKind, ObligationCause, ObligationCtxt, coherence, + predicates_for_generics, +}; /// Information pertinent to an overlapping impl error. #[derive(Debug)] @@ -87,9 +87,14 @@ pub fn translate_args<'tcx>( source_args: GenericArgsRef<'tcx>, target_node: specialization_graph::Node, ) -> GenericArgsRef<'tcx> { - translate_args_with_cause(infcx, param_env, source_impl, source_args, target_node, |_, _| { - ObligationCause::dummy() - }) + translate_args_with_cause( + infcx, + param_env, + source_impl, + source_args, + target_node, + &ObligationCause::dummy(), + ) } /// Like [translate_args], but obligations from the parent implementation @@ -104,7 +109,7 @@ pub fn translate_args_with_cause<'tcx>( source_impl: DefId, source_args: GenericArgsRef<'tcx>, target_node: specialization_graph::Node, - cause: impl Fn(usize, Span) -> ObligationCause<'tcx>, + cause: &ObligationCause<'tcx>, ) -> GenericArgsRef<'tcx> { debug!( "translate_args({:?}, {:?}, {:?}, {:?})", @@ -123,7 +128,7 @@ pub fn translate_args_with_cause<'tcx>( } fulfill_implication(infcx, param_env, source_trait_ref, source_impl, target_impl, cause) - .unwrap_or_else(|()| { + .unwrap_or_else(|_| { bug!( "When translating generic parameters from {source_impl:?} to \ {target_impl:?}, the expected specialization failed to hold" @@ -137,6 +142,84 @@ pub fn translate_args_with_cause<'tcx>( source_args.rebase_onto(infcx.tcx, source_impl, target_args) } +/// Attempt to fulfill all obligations of `target_impl` after unification with +/// `source_trait_ref`. If successful, returns the generic parameters for *all* the +/// generics of `target_impl`, including both those needed to unify with +/// `source_trait_ref` and those whose identity is determined via a where +/// clause in the impl. +fn fulfill_implication<'tcx>( + infcx: &InferCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + source_trait_ref: ty::TraitRef<'tcx>, + source_impl: DefId, + target_impl: DefId, + cause: &ObligationCause<'tcx>, +) -> Result, NoSolution> { + debug!( + "fulfill_implication({:?}, trait_ref={:?} |- {:?} applies)", + param_env, source_trait_ref, target_impl + ); + + let ocx = ObligationCtxt::new(infcx); + let source_trait_ref = ocx.normalize(cause, param_env, source_trait_ref); + + if !ocx.select_all_or_error().is_empty() { + infcx.dcx().span_delayed_bug( + infcx.tcx.def_span(source_impl), + format!("failed to fully normalize {source_trait_ref}"), + ); + return Err(NoSolution); + } + + let target_args = infcx.fresh_args_for_item(DUMMY_SP, target_impl); + let target_trait_ref = ocx.normalize( + cause, + param_env, + infcx + .tcx + .impl_trait_ref(target_impl) + .expect("expected source impl to be a trait impl") + .instantiate(infcx.tcx, target_args), + ); + + // do the impls unify? If not, no specialization. + ocx.eq(cause, param_env, source_trait_ref, target_trait_ref)?; + + // Now check that the source trait ref satisfies all the where clauses of the target impl. + // This is not just for correctness; we also need this to constrain any params that may + // only be referenced via projection predicates. + let predicates = ocx.normalize( + cause, + param_env, + infcx.tcx.predicates_of(target_impl).instantiate(infcx.tcx, target_args), + ); + let obligations = predicates_for_generics(|_, _| cause.clone(), param_env, predicates); + ocx.register_obligations(obligations); + + let errors = ocx.select_all_or_error(); + if !errors.is_empty() { + // no dice! + debug!( + "fulfill_implication: for impls on {:?} and {:?}, \ + could not fulfill: {:?} given {:?}", + source_trait_ref, + target_trait_ref, + errors, + param_env.caller_bounds() + ); + return Err(NoSolution); + } + + debug!( + "fulfill_implication: an impl for {:?} specializes {:?}", + source_trait_ref, target_trait_ref + ); + + // Now resolve the *generic parameters* we built for the target earlier, replacing + // the inference variables inside with whatever we got from fulfillment. + Ok(infcx.resolve_vars_if_possible(target_args)) +} + pub(super) fn specialization_enabled_in(tcx: TyCtxt<'_>, _: LocalCrate) -> bool { tcx.features().specialization() || tcx.features().min_specialization() } @@ -182,8 +265,9 @@ pub(super) fn specializes(tcx: TyCtxt<'_>, (impl1_def_id, impl2_def_id): (DefId, return false; } - // create a parameter environment corresponding to a (placeholder) instantiation of impl1 - let penv = tcx.param_env(impl1_def_id); + // create a parameter environment corresponding to an identity instantiation of impl1, + // i.e. the most generic instantiation of impl1. + let param_env = tcx.param_env(impl1_def_id); // Create an infcx, taking the predicates of impl1 as assumptions: let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis()); @@ -191,90 +275,15 @@ pub(super) fn specializes(tcx: TyCtxt<'_>, (impl1_def_id, impl2_def_id): (DefId, // Attempt to prove that impl2 applies, given all of the above. fulfill_implication( &infcx, - penv, + param_env, impl1_trait_header.trait_ref.instantiate_identity(), impl1_def_id, impl2_def_id, - |_, _| ObligationCause::dummy(), + &ObligationCause::dummy(), ) .is_ok() } -/// Attempt to fulfill all obligations of `target_impl` after unification with -/// `source_trait_ref`. If successful, returns the generic parameters for *all* the -/// generics of `target_impl`, including both those needed to unify with -/// `source_trait_ref` and those whose identity is determined via a where -/// clause in the impl. -fn fulfill_implication<'tcx>( - infcx: &InferCtxt<'tcx>, - param_env: ty::ParamEnv<'tcx>, - source_trait_ref: ty::TraitRef<'tcx>, - source_impl: DefId, - target_impl: DefId, - error_cause: impl Fn(usize, Span) -> ObligationCause<'tcx>, -) -> Result, ()> { - debug!( - "fulfill_implication({:?}, trait_ref={:?} |- {:?} applies)", - param_env, source_trait_ref, target_impl - ); - - let ocx = ObligationCtxt::new(infcx); - let source_trait_ref = ocx.normalize(&ObligationCause::dummy(), param_env, source_trait_ref); - - if !ocx.select_all_or_error().is_empty() { - infcx.dcx().span_delayed_bug( - infcx.tcx.def_span(source_impl), - format!("failed to fully normalize {source_trait_ref}"), - ); - } - - let source_trait_ref = infcx.resolve_vars_if_possible(source_trait_ref); - let source_trait = ImplSubject::Trait(source_trait_ref); - - let selcx = SelectionContext::new(infcx); - let target_args = infcx.fresh_args_for_item(DUMMY_SP, target_impl); - let (target_trait, obligations) = - util::impl_subject_and_oblig(&selcx, param_env, target_impl, target_args, error_cause); - - // do the impls unify? If not, no specialization. - let Ok(InferOk { obligations: more_obligations, .. }) = infcx - .at(&ObligationCause::dummy(), param_env) - // Ok to use `Yes`, as all the generic params are already replaced by inference variables, - // which will match the opaque type no matter if it is defining or not. - // Any concrete type that would match the opaque would already be handled by coherence rules, - // and thus either be ok to match here and already have errored, or it won't match, in which - // case there is no issue anyway. - .eq(DefineOpaqueTypes::Yes, source_trait, target_trait) - else { - debug!("fulfill_implication: {:?} does not unify with {:?}", source_trait, target_trait); - return Err(()); - }; - - // attempt to prove all of the predicates for impl2 given those for impl1 - // (which are packed up in penv) - ocx.register_obligations(obligations.chain(more_obligations)); - - let errors = ocx.select_all_or_error(); - if !errors.is_empty() { - // no dice! - debug!( - "fulfill_implication: for impls on {:?} and {:?}, \ - could not fulfill: {:?} given {:?}", - source_trait, - target_trait, - errors, - param_env.caller_bounds() - ); - return Err(()); - } - - debug!("fulfill_implication: an impl for {:?} specializes {:?}", source_trait, target_trait); - - // Now resolve the *generic parameters* we built for the target earlier, replacing - // the inference variables inside with whatever we got from fulfillment. - Ok(infcx.resolve_vars_if_possible(target_args)) -} - /// Query provider for `specialization_graph_of`. pub(super) fn specialization_graph_provider( tcx: TyCtxt<'_>, diff --git a/compiler/rustc_trait_selection/src/traits/util.rs b/compiler/rustc_trait_selection/src/traits/util.rs index b7a2f20b769..da1045b664a 100644 --- a/compiler/rustc_trait_selection/src/traits/util.rs +++ b/compiler/rustc_trait_selection/src/traits/util.rs @@ -3,19 +3,16 @@ use std::collections::BTreeMap; use rustc_data_structures::fx::FxIndexMap; use rustc_errors::Diag; use rustc_hir::def_id::DefId; -use rustc_infer::infer::{InferCtxt, InferOk}; +use rustc_infer::infer::InferCtxt; pub use rustc_infer::traits::util::*; use rustc_middle::bug; use rustc_middle::ty::{ - self, GenericArgsRef, ImplSubject, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, - TypeVisitableExt, Upcast, + self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, Upcast, }; use rustc_span::Span; use smallvec::{SmallVec, smallvec}; use tracing::debug; -use super::{NormalizeExt, ObligationCause, PredicateObligation, SelectionContext}; - /////////////////////////////////////////////////////////////////////////// // `TraitAliasExpander` iterator /////////////////////////////////////////////////////////////////////////// @@ -166,34 +163,6 @@ impl<'tcx> Iterator for TraitAliasExpander<'tcx> { // Other /////////////////////////////////////////////////////////////////////////// -/// Instantiate all bound parameters of the impl subject with the given args, -/// returning the resulting subject and all obligations that arise. -/// The obligations are closed under normalization. -pub(crate) fn impl_subject_and_oblig<'a, 'tcx>( - selcx: &SelectionContext<'a, 'tcx>, - param_env: ty::ParamEnv<'tcx>, - impl_def_id: DefId, - impl_args: GenericArgsRef<'tcx>, - cause: impl Fn(usize, Span) -> ObligationCause<'tcx>, -) -> (ImplSubject<'tcx>, impl Iterator>) { - let subject = selcx.tcx().impl_subject(impl_def_id); - let subject = subject.instantiate(selcx.tcx(), impl_args); - - let InferOk { value: subject, obligations: normalization_obligations1 } = - selcx.infcx.at(&ObligationCause::dummy(), param_env).normalize(subject); - - let predicates = selcx.tcx().predicates_of(impl_def_id); - let predicates = predicates.instantiate(selcx.tcx(), impl_args); - let InferOk { value: predicates, obligations: normalization_obligations2 } = - selcx.infcx.at(&ObligationCause::dummy(), param_env).normalize(predicates); - let impl_obligations = super::predicates_for_generics(cause, param_env, predicates); - - let impl_obligations = - impl_obligations.chain(normalization_obligations1).chain(normalization_obligations2); - - (subject, impl_obligations) -} - /// Casts a trait reference into a reference to one of its super /// traits; returns `None` if `target_trait_def_id` is not a /// supertrait. diff --git a/tests/ui/specialization/min_specialization/specialize_with_generalize_lifetimes.stderr b/tests/ui/specialization/min_specialization/specialize_with_generalize_lifetimes.stderr index 2af75876d5b..04a41f0d9dd 100644 --- a/tests/ui/specialization/min_specialization/specialize_with_generalize_lifetimes.stderr +++ b/tests/ui/specialization/min_specialization/specialize_with_generalize_lifetimes.stderr @@ -4,11 +4,7 @@ error[E0477]: the type `&'a i32` does not fulfill the required lifetime LL | impl<'a> Tr for &'a i32 { | ^^^^^^^^^^^^^^^^^^^^^^^ | -note: type must satisfy the static lifetime as required by this binding - --> $DIR/specialize_with_generalize_lifetimes.rs:12:15 - | -LL | impl Tr for T { - | ^^^^^^^ + = note: type must satisfy the static lifetime error[E0477]: the type `Wrapper<'a>` does not fulfill the required lifetime --> $DIR/specialize_with_generalize_lifetimes.rs:31:1 @@ -16,11 +12,7 @@ error[E0477]: the type `Wrapper<'a>` does not fulfill the required lifetime LL | impl<'a> Tr for Wrapper<'a> { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ | -note: type must satisfy the static lifetime as required by this binding - --> $DIR/specialize_with_generalize_lifetimes.rs:12:15 - | -LL | impl Tr for T { - | ^^^^^^^ + = note: type must satisfy the static lifetime error: aborting due to 2 previous errors