From 45f441a7b41d09bb78b0ba3e260d2c868ef6add7 Mon Sep 17 00:00:00 2001 From: lcnr Date: Mon, 14 Nov 2022 17:42:46 +0100 Subject: [PATCH] nll: correctly deal with bivariance --- .../src/interpret/eval_context.rs | 6 +- .../src/transform/validate.rs | 85 +++++++++---------- .../rustc_infer/src/infer/nll_relate/mod.rs | 8 +- compiler/rustc_mir_transform/src/inline.rs | 12 +-- .../ui/mir/important-higher-ranked-regions.rs | 26 ++++++ 5 files changed, 82 insertions(+), 55 deletions(-) create mode 100644 src/test/ui/mir/important-higher-ranked-regions.rs diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs index ab82268dde3..753c1e87d21 100644 --- a/compiler/rustc_const_eval/src/interpret/eval_context.rs +++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs @@ -23,7 +23,7 @@ use super::{ MemPlaceMeta, Memory, MemoryKind, Operand, Place, PlaceTy, PointerArithmetic, Provenance, Scalar, StackPopJump, }; -use crate::transform::validate::equal_up_to_regions; +use crate::transform::validate; pub struct InterpCx<'mir, 'tcx, M: Machine<'mir, 'tcx>> { /// Stores the `Machine` instance. @@ -354,8 +354,8 @@ pub(super) fn mir_assign_valid_types<'tcx>( // Type-changing assignments can happen when subtyping is used. While // all normal lifetimes are erased, higher-ranked types with their // late-bound lifetimes are still around and can lead to type - // differences. So we compare ignoring lifetimes. - if equal_up_to_regions(tcx, param_env, src.ty, dest.ty) { + // differences. + if validate::is_subtype(tcx, param_env, src.ty, dest.ty) { // Make sure the layout is equal, too -- just to be safe. Miri really // needs layout equality. For performance reason we skip this check when // the types are equal. Equal types *can* have different layouts when diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs index 81b82a21fa1..69e457f0a1a 100644 --- a/compiler/rustc_const_eval/src/transform/validate.rs +++ b/compiler/rustc_const_eval/src/transform/validate.rs @@ -2,7 +2,8 @@ use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; -use rustc_infer::infer::TyCtxtInferExt; +use rustc_infer::infer::{DefiningAnchor, TyCtxtInferExt}; +use rustc_infer::traits::ObligationCause; use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::visit::NonUseContext::VarDebugInfo; use rustc_middle::mir::visit::{PlaceContext, Visitor}; @@ -12,12 +13,12 @@ use rustc_middle::mir::{ ProjectionElem, RuntimePhase, Rvalue, SourceScope, Statement, StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK, }; -use rustc_middle::ty::fold::BottomUpFolder; -use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeFoldable, TypeVisitable}; +use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeVisitable}; use rustc_mir_dataflow::impls::MaybeStorageLive; use rustc_mir_dataflow::storage::always_storage_live_locals; use rustc_mir_dataflow::{Analysis, ResultsCursor}; use rustc_target::abi::{Size, VariantIdx}; +use rustc_trait_selection::traits::ObligationCtxt; #[derive(Copy, Clone, Debug)] enum EdgeKind { @@ -70,13 +71,11 @@ impl<'tcx> MirPass<'tcx> for Validator { } } -/// Returns whether the two types are equal up to lifetimes. -/// All lifetimes, including higher-ranked ones, get ignored for this comparison. -/// (This is unlike the `erasing_regions` methods, which keep higher-ranked lifetimes for soundness reasons.) +/// Returns whether the two types are equal up to subtyping. /// -/// The point of this function is to approximate "equal up to subtyping". However, -/// the approximation is incorrect as variance is ignored. -pub fn equal_up_to_regions<'tcx>( +/// This is used in case we don't know the expected subtyping direction +/// and still want to check whether anything is broken. +pub fn is_equal_up_to_subtyping<'tcx>( tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, src: Ty<'tcx>, @@ -87,27 +86,40 @@ pub fn equal_up_to_regions<'tcx>( return true; } - // Normalize lifetimes away on both sides, then compare. - let normalize = |ty: Ty<'tcx>| { - tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty).fold_with( - &mut BottomUpFolder { - tcx, - // FIXME: We erase all late-bound lifetimes, but this is not fully correct. - // If you have a type like ` fn(&'a u32) as SomeTrait>::Assoc`, - // this is not necessarily equivalent to `::Assoc`, - // since one may have an `impl SomeTrait for fn(&32)` and - // `impl SomeTrait for fn(&'static u32)` at the same time which - // specify distinct values for Assoc. (See also #56105) - lt_op: |_| tcx.lifetimes.re_erased, - // Leave consts and types unchanged. - ct_op: |ct| ct, - ty_op: |ty| ty, - }, - ) - }; - tcx.infer_ctxt().build().can_eq(param_env, normalize(src), normalize(dest)).is_ok() + // Check for subtyping in either direction. + is_subtype(tcx, param_env, src, dest) || is_subtype(tcx, param_env, dest, src) } +pub fn is_subtype<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + src: Ty<'tcx>, + dest: Ty<'tcx>, +) -> bool { + if src == dest { + return true; + } + + let mut builder = + tcx.infer_ctxt().ignoring_regions().with_opaque_type_inference(DefiningAnchor::Bubble); + let infcx = builder.build(); + let ocx = ObligationCtxt::new(&infcx); + let cause = ObligationCause::dummy(); + let src = ocx.normalize(cause.clone(), param_env, src); + let dest = ocx.normalize(cause.clone(), param_env, dest); + let Ok(infer_ok) = infcx.at(&cause, param_env).sub(src, dest) else { + return false; + }; + let () = ocx.register_infer_ok_obligations(infer_ok); + let errors = ocx.select_all_or_error(); + // With `Reveal::All`, opaque types get normalized away, with `Reveal::UserFacing` + // we would get unification errors because we're unable to look into opaque types, + // even if they're constrained in our current function. + // + // It seems very unlikely that this hides any bugs. + let _ = infcx.inner.borrow_mut().opaque_type_storage.take_opaque_types(); + errors.is_empty() +} struct TypeChecker<'a, 'tcx> { when: &'a str, body: &'a Body<'tcx>, @@ -183,22 +195,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { return true; } - // Normalize projections and things like that. - // Type-changing assignments can happen when subtyping is used. While - // all normal lifetimes are erased, higher-ranked types with their - // late-bound lifetimes are still around and can lead to type - // differences. So we compare ignoring lifetimes. - - // First, try with reveal_all. This might not work in some cases, as the predicates - // can be cleared in reveal_all mode. We try the reveal first anyways as it is used - // by some other passes like inlining as well. - let param_env = self.param_env.with_reveal_all_normalized(self.tcx); - if equal_up_to_regions(self.tcx, param_env, src, dest) { - return true; - } - - // If this fails, we can try it without the reveal. - equal_up_to_regions(self.tcx, self.param_env, src, dest) + is_subtype(self.tcx, self.param_env, src, dest) } } diff --git a/compiler/rustc_infer/src/infer/nll_relate/mod.rs b/compiler/rustc_infer/src/infer/nll_relate/mod.rs index 600f94f095e..74b2fb613b7 100644 --- a/compiler/rustc_infer/src/infer/nll_relate/mod.rs +++ b/compiler/rustc_infer/src/infer/nll_relate/mod.rs @@ -556,8 +556,12 @@ where self.ambient_variance_info = self.ambient_variance_info.xform(info); debug!(?self.ambient_variance); - - let r = self.relate(a, b)?; + // In a bivariant context this always succeeds. + let r = if self.ambient_variance == ty::Variance::Bivariant { + a + } else { + self.relate(a, b)? + }; self.ambient_variance = old_ambient_variance; diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 780b91d9215..2084492cc63 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -1,7 +1,6 @@ //! Inlining pass for MIR functions use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; -use rustc_const_eval::transform::validate::equal_up_to_regions; use rustc_index::bit_set::BitSet; use rustc_index::vec::Idx; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; @@ -14,7 +13,8 @@ use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span}; use rustc_target::abi::VariantIdx; use rustc_target::spec::abi::Abi; -use super::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::validate; use crate::MirPass; use std::iter; use std::ops::{Range, RangeFrom}; @@ -180,7 +180,7 @@ impl<'tcx> Inliner<'tcx> { let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() }; let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty; let output_type = callee_body.return_ty(); - if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) { + if !validate::is_subtype(self.tcx, self.param_env, output_type, destination_ty) { trace!(?output_type, ?destination_ty); return Err("failed to normalize return type"); } @@ -200,7 +200,7 @@ impl<'tcx> Inliner<'tcx> { arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args)) { let input_type = callee_body.local_decls[input].ty; - if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) { + if !validate::is_subtype(self.tcx, self.param_env, input_type, arg_ty) { trace!(?arg_ty, ?input_type); return Err("failed to normalize tuple argument type"); } @@ -209,7 +209,7 @@ impl<'tcx> Inliner<'tcx> { for (arg, input) in args.iter().zip(callee_body.args_iter()) { let input_type = callee_body.local_decls[input].ty; let arg_ty = arg.ty(&caller_body.local_decls, self.tcx); - if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) { + if !validate::is_subtype(self.tcx, self.param_env, input_type, arg_ty) { trace!(?arg_ty, ?input_type); return Err("failed to normalize argument type"); } @@ -847,7 +847,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) }; let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx); let check_equal = |this: &mut Self, f_ty| { - if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) { + if !validate::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) { trace!(?ty, ?f_ty); this.validation = Err("failed to normalize projection type"); return; diff --git a/src/test/ui/mir/important-higher-ranked-regions.rs b/src/test/ui/mir/important-higher-ranked-regions.rs new file mode 100644 index 00000000000..cadfb3b66f2 --- /dev/null +++ b/src/test/ui/mir/important-higher-ranked-regions.rs @@ -0,0 +1,26 @@ +// check-pass +// compile-flags: -Zvalidate-mir + +// This test checks that bivariant parameters are handled correctly +// in the mir. +#![allow(coherence_leak_check)] +trait Trait { + type Assoc; +} + +struct Foo(T) +where + T: Trait; + +impl Trait for for<'a> fn(&'a ()) { + type Assoc = u32; +} +impl Trait for fn(&'static ()) { + type Assoc = String; +} + +fn foo(x: Foo fn(&'a ()), u32>) -> Foo { + x +} + +fn main() {}