Equate full fn signatures to infer all region variables

This commit is contained in:
Michael Goulet 2022-10-10 23:35:51 +00:00
parent cb20758257
commit e994de803d

View File

@ -465,30 +465,30 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
let ocx = ObligationCtxt::new(infcx); let ocx = ObligationCtxt::new(infcx);
let norm_cause = ObligationCause::misc(return_span, impl_m_hir_id); let norm_cause = ObligationCause::misc(return_span, impl_m_hir_id);
let impl_return_ty = ocx.normalize( let impl_sig = ocx.normalize(
norm_cause.clone(), norm_cause.clone(),
param_env, param_env,
infcx infcx.replace_bound_vars_with_fresh_vars(
.replace_bound_vars_with_fresh_vars( return_span,
return_span, infer::HigherRankedType,
infer::HigherRankedType, tcx.fn_sig(impl_m.def_id),
tcx.fn_sig(impl_m.def_id), ),
)
.output(),
); );
let impl_return_ty = impl_sig.output();
let mut collector = ImplTraitInTraitCollector::new(&ocx, return_span, param_env, impl_m_hir_id); let mut collector = ImplTraitInTraitCollector::new(&ocx, return_span, param_env, impl_m_hir_id);
let unnormalized_trait_return_ty = tcx let unnormalized_trait_sig = tcx
.liberate_late_bound_regions( .liberate_late_bound_regions(
impl_m.def_id, impl_m.def_id,
tcx.bound_fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs), tcx.bound_fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs),
) )
.output()
.fold_with(&mut collector); .fold_with(&mut collector);
let trait_return_ty = let trait_sig = ocx.normalize(norm_cause.clone(), param_env, unnormalized_trait_sig);
ocx.normalize(norm_cause.clone(), param_env, unnormalized_trait_return_ty); let trait_return_ty = trait_sig.output();
let wf_tys = FxHashSet::from_iter([unnormalized_trait_return_ty, trait_return_ty]); let wf_tys = FxHashSet::from_iter(
unnormalized_trait_sig.inputs_and_output.iter().chain(trait_sig.inputs_and_output.iter()),
);
match infcx.at(&cause, param_env).eq(trait_return_ty, impl_return_ty) { match infcx.at(&cause, param_env).eq(trait_return_ty, impl_return_ty) {
Ok(infer::InferOk { value: (), obligations }) => { Ok(infer::InferOk { value: (), obligations }) => {
@ -521,6 +521,26 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
} }
} }
// Unify the whole function signature. We need to do this to fully infer
// the lifetimes of the return type, but do this after unifying just the
// return types, since we want to avoid duplicating errors from
// `compare_predicate_entailment`.
match infcx
.at(&cause, param_env)
.eq(tcx.mk_fn_ptr(ty::Binder::dummy(trait_sig)), tcx.mk_fn_ptr(ty::Binder::dummy(impl_sig)))
{
Ok(infer::InferOk { value: (), obligations }) => {
ocx.register_obligations(obligations);
}
Err(terr) => {
let guar = tcx.sess.delay_span_bug(
return_span,
format!("could not unify `{trait_sig}` and `{impl_sig}`: {terr:?}"),
);
return Err(guar);
}
}
// Check that all obligations are satisfied by the implementation's // Check that all obligations are satisfied by the implementation's
// RPITs. // RPITs.
let errors = ocx.select_all_or_error(); let errors = ocx.select_all_or_error();