From 55f8c66a601236b422e35f56f7e414a8280c78d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Esteban=20K=C3=BCber?= Date: Mon, 14 Aug 2023 18:36:03 +0000 Subject: [PATCH] Point at return type when it influences non-first `match` arm When encountering code like ```rust fn foo() -> i32 { match 0 { 1 => return 0, 2 => "", _ => 1, } } ``` Point at the return type and not at the prior arm, as that arm has type `!` which isn't influencing the arm corresponding to arm `2`. Fix #78124. --- compiler/rustc_hir_typeck/src/_match.rs | 13 ++- compiler/rustc_hir_typeck/src/coercion.rs | 2 +- .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 9 ++- .../src/infer/error_reporting/mod.rs | 80 +++++++++++++++---- .../nice_region_error/static_impl_trait.rs | 2 +- compiler/rustc_middle/src/traits/mod.rs | 2 +- .../src/traits/error_reporting/suggestions.rs | 2 +- .../did_you_mean/compatible-variants.stderr | 2 + ...1632-try-desugar-incompatible-types.stderr | 2 + ...t-arm-doesnt-match-expected-return-type.rs | 21 +++++ ...m-doesnt-match-expected-return-type.stderr | 12 +++ .../remove-question-symbol-with-paren.stderr | 3 + 12 files changed, 129 insertions(+), 21 deletions(-) create mode 100644 tests/ui/match/non-first-arm-doesnt-match-expected-return-type.rs create mode 100644 tests/ui/match/non-first-arm-doesnt-match-expected-return-type.stderr diff --git a/compiler/rustc_hir_typeck/src/_match.rs b/compiler/rustc_hir_typeck/src/_match.rs index 6d926cd8aa1..e565dbfe8d2 100644 --- a/compiler/rustc_hir_typeck/src/_match.rs +++ b/compiler/rustc_hir_typeck/src/_match.rs @@ -107,7 +107,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let (span, code) = match prior_arm { // The reason for the first arm to fail is not that the match arms diverge, // but rather that there's a prior obligation that doesn't hold. - None => (arm_span, ObligationCauseCode::BlockTailExpression(arm.body.hir_id)), + None => ( + arm_span, + ObligationCauseCode::BlockTailExpression( + arm.body.hir_id, + scrut.hir_id, + match_src, + ), + ), Some((prior_arm_block_id, prior_arm_ty, prior_arm_span)) => ( expr.span, ObligationCauseCode::MatchExpressionArm(Box::new(MatchExpressionArmCause { @@ -145,7 +152,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { other_arms.remove(0); } - prior_arm = Some((arm_block_id, arm_ty, arm_span)); + if !arm_ty.is_never() { + prior_arm = Some((arm_block_id, arm_ty, arm_span)); + } } // If all of the arms in the `match` diverge, diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs index 1cfdc5b9e7f..ca616fd1518 100644 --- a/compiler/rustc_hir_typeck/src/coercion.rs +++ b/compiler/rustc_hir_typeck/src/coercion.rs @@ -1603,7 +1603,7 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> { ); err.span_label(cause.span, "return type is not `()`"); } - ObligationCauseCode::BlockTailExpression(blk_id) => { + ObligationCauseCode::BlockTailExpression(blk_id, ..) => { let parent_id = fcx.tcx.hir().parent_id(blk_id); err = self.report_return_mismatched_types( cause, diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index 40f9a954034..5254f0796d8 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -1580,7 +1580,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let coerce = ctxt.coerce.as_mut().unwrap(); if let Some((tail_expr, tail_expr_ty)) = tail_expr_ty { let span = self.get_expr_coercion_span(tail_expr); - let cause = self.cause(span, ObligationCauseCode::BlockTailExpression(blk.hir_id)); + let cause = self.cause( + span, + ObligationCauseCode::BlockTailExpression( + blk.hir_id, + blk.hir_id, + hir::MatchSource::Normal, + ), + ); let ty_for_diagnostic = coerce.merged_ty(); // We use coerce_inner here because we want to augment the error // suggesting to wrap the block in square brackets if it might've diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs index 75cca973306..717103ed0b4 100644 --- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs +++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs @@ -743,6 +743,36 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { ObligationCauseCode::Pattern { origin_expr: false, span: Some(span), .. } => { err.span_label(span, "expected due to this"); } + ObligationCauseCode::BlockTailExpression( + _, + scrut_hir_id, + hir::MatchSource::TryDesugar, + ) => { + if let Some(ty::error::ExpectedFound { expected, .. }) = exp_found { + let scrut_expr = self.tcx.hir().expect_expr(scrut_hir_id); + let scrut_ty = if let hir::ExprKind::Call(_, args) = &scrut_expr.kind { + let arg_expr = args.first().expect("try desugaring call w/out arg"); + self.typeck_results.as_ref().and_then(|typeck_results| { + typeck_results.expr_ty_opt(arg_expr) + }) + } else { + bug!("try desugaring w/out call expr as scrutinee"); + }; + + match scrut_ty { + Some(ty) if expected == ty => { + let source_map = self.tcx.sess.source_map(); + err.span_suggestion( + source_map.end_point(cause.span()), + "try removing this `?`", + "", + Applicability::MachineApplicable, + ); + } + _ => {} + } + } + }, ObligationCauseCode::MatchExpressionArm(box MatchExpressionArmCause { arm_block_id, arm_span, @@ -1973,7 +2003,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { trace: &TypeTrace<'tcx>, terr: TypeError<'tcx>, ) -> Vec { - use crate::traits::ObligationCauseCode::MatchExpressionArm; + use crate::traits::ObligationCauseCode::{BlockTailExpression, MatchExpressionArm}; let mut suggestions = Vec::new(); let span = trace.cause.span(); let values = self.resolve_vars_if_possible(trace.values); @@ -1991,11 +2021,17 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { // specify a byte literal (ty::Uint(ty::UintTy::U8), ty::Char) => { if let Ok(code) = self.tcx.sess().source_map().span_to_snippet(span) - && let Some(code) = code.strip_prefix('\'').and_then(|s| s.strip_suffix('\'')) - && !code.starts_with("\\u") // forbid all Unicode escapes - && code.chars().next().is_some_and(|c| c.is_ascii()) // forbids literal Unicode characters beyond ASCII + && let Some(code) = + code.strip_prefix('\'').and_then(|s| s.strip_suffix('\'')) + // forbid all Unicode escapes + && !code.starts_with("\\u") + // forbids literal Unicode characters beyond ASCII + && code.chars().next().is_some_and(|c| c.is_ascii()) { - suggestions.push(TypeErrorAdditionalDiags::MeantByteLiteral { span, code: escape_literal(code) }) + suggestions.push(TypeErrorAdditionalDiags::MeantByteLiteral { + span, + code: escape_literal(code), + }) } } // If a character was expected and the found expression is a string literal @@ -2006,7 +2042,10 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { && let Some(code) = code.strip_prefix('"').and_then(|s| s.strip_suffix('"')) && code.chars().count() == 1 { - suggestions.push(TypeErrorAdditionalDiags::MeantCharLiteral { span, code: escape_literal(code) }) + suggestions.push(TypeErrorAdditionalDiags::MeantCharLiteral { + span, + code: escape_literal(code), + }) } } // If a string was expected and the found expression is a character literal, @@ -2016,7 +2055,10 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { if let Some(code) = code.strip_prefix('\'').and_then(|s| s.strip_suffix('\'')) { - suggestions.push(TypeErrorAdditionalDiags::MeantStrLiteral { span, code: escape_literal(code) }) + suggestions.push(TypeErrorAdditionalDiags::MeantStrLiteral { + span, + code: escape_literal(code), + }) } } } @@ -2025,17 +2067,24 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { (ty::Bool, ty::Tuple(list)) => if list.len() == 0 { suggestions.extend(self.suggest_let_for_letchains(&trace.cause, span)); } - (ty::Array(_, _), ty::Array(_, _)) => suggestions.extend(self.suggest_specify_actual_length(terr, trace, span)), + (ty::Array(_, _), ty::Array(_, _)) => { + suggestions.extend(self.suggest_specify_actual_length(terr, trace, span)) + } _ => {} } } let code = trace.cause.code(); - if let &MatchExpressionArm(box MatchExpressionArmCause { source, .. }) = code - && let hir::MatchSource::TryDesugar = source - && let Some((expected_ty, found_ty, _, _)) = self.values_str(trace.values) - { - suggestions.push(TypeErrorAdditionalDiags::TryCannotConvert { found: found_ty.content(), expected: expected_ty.content() }); - } + if let &(MatchExpressionArm(box MatchExpressionArmCause { source, .. }) + | BlockTailExpression(.., source) + ) = code + && let hir::MatchSource::TryDesugar = source + && let Some((expected_ty, found_ty, _, _)) = self.values_str(trace.values) + { + suggestions.push(TypeErrorAdditionalDiags::TryCannotConvert { + found: found_ty.content(), + expected: expected_ty.content(), + }); + } suggestions } @@ -2905,6 +2954,9 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> { CompareImplItemObligation { kind: ty::AssocKind::Const, .. } => { ObligationCauseFailureCode::ConstCompat { span, subdiags } } + BlockTailExpression(.., hir::MatchSource::TryDesugar) => { + ObligationCauseFailureCode::TryCompat { span, subdiags } + } MatchExpressionArm(box MatchExpressionArmCause { source, .. }) => match source { hir::MatchSource::TryDesugar => { ObligationCauseFailureCode::TryCompat { span, subdiags } diff --git a/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/static_impl_trait.rs b/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/static_impl_trait.rs index d08b6ba5e47..3cfda0cc5c0 100644 --- a/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/static_impl_trait.rs +++ b/compiler/rustc_infer/src/infer/error_reporting/nice_region_error/static_impl_trait.rs @@ -146,7 +146,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> { if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = sub_origin { if let ObligationCauseCode::ReturnValue(hir_id) - | ObligationCauseCode::BlockTailExpression(hir_id) = cause.code() + | ObligationCauseCode::BlockTailExpression(hir_id, ..) = cause.code() { let parent_id = tcx.hir().get_parent_item(*hir_id); if let Some(fn_decl) = tcx.hir().fn_decl_by_hir_id(parent_id.into()) { diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs index 2d655041c32..1845d42bf7f 100644 --- a/compiler/rustc_middle/src/traits/mod.rs +++ b/compiler/rustc_middle/src/traits/mod.rs @@ -402,7 +402,7 @@ pub enum ObligationCauseCode<'tcx> { OpaqueReturnType(Option<(Ty<'tcx>, Span)>), /// Block implicit return - BlockTailExpression(hir::HirId), + BlockTailExpression(hir::HirId, hir::HirId, hir::MatchSource), /// #[feature(trivial_bounds)] is not enabled TrivialBound, diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index d071cf76fd3..5e075984238 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -2700,7 +2700,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { | ObligationCauseCode::MatchImpl(..) | ObligationCauseCode::ReturnType | ObligationCauseCode::ReturnValue(_) - | ObligationCauseCode::BlockTailExpression(_) + | ObligationCauseCode::BlockTailExpression(..) | ObligationCauseCode::AwaitableExpr(_) | ObligationCauseCode::ForLoopIterator | ObligationCauseCode::QuestionMark diff --git a/tests/ui/did_you_mean/compatible-variants.stderr b/tests/ui/did_you_mean/compatible-variants.stderr index 7b88d93ead1..f2bbd8ced8f 100644 --- a/tests/ui/did_you_mean/compatible-variants.stderr +++ b/tests/ui/did_you_mean/compatible-variants.stderr @@ -61,6 +61,8 @@ LL + Some(()) error[E0308]: `?` operator has incompatible types --> $DIR/compatible-variants.rs:35:5 | +LL | fn d() -> Option<()> { + | ---------- expected `Option<()>` because of return type LL | c()? | ^^^^ expected `Option<()>`, found `()` | diff --git a/tests/ui/issues/issue-51632-try-desugar-incompatible-types.stderr b/tests/ui/issues/issue-51632-try-desugar-incompatible-types.stderr index 7180a3d2426..c92da53dbc4 100644 --- a/tests/ui/issues/issue-51632-try-desugar-incompatible-types.stderr +++ b/tests/ui/issues/issue-51632-try-desugar-incompatible-types.stderr @@ -1,6 +1,8 @@ error[E0308]: `?` operator has incompatible types --> $DIR/issue-51632-try-desugar-incompatible-types.rs:8:5 | +LL | fn forbidden_narratives() -> Result { + | ----------------- expected `Result` because of return type LL | missing_discourses()? | ^^^^^^^^^^^^^^^^^^^^^ expected `Result`, found `isize` | diff --git a/tests/ui/match/non-first-arm-doesnt-match-expected-return-type.rs b/tests/ui/match/non-first-arm-doesnt-match-expected-return-type.rs new file mode 100644 index 00000000000..85b1ef7555e --- /dev/null +++ b/tests/ui/match/non-first-arm-doesnt-match-expected-return-type.rs @@ -0,0 +1,21 @@ +#![allow(unused)] + +fn test(shouldwe: Option, shouldwe2: Option) -> u32 { + //~^ NOTE expected `u32` because of return type + match shouldwe { + Some(val) => { + match shouldwe2 { + Some(val) => { + return val; + } + None => (), //~ ERROR mismatched types + //~^ NOTE expected `u32`, found `()` + } + } + None => return 12, + } +} + +fn main() { + println!("returned {}", test(None, Some(5))); +} diff --git a/tests/ui/match/non-first-arm-doesnt-match-expected-return-type.stderr b/tests/ui/match/non-first-arm-doesnt-match-expected-return-type.stderr new file mode 100644 index 00000000000..e6d93b8b5f5 --- /dev/null +++ b/tests/ui/match/non-first-arm-doesnt-match-expected-return-type.stderr @@ -0,0 +1,12 @@ +error[E0308]: mismatched types + --> $DIR/non-first-arm-doesnt-match-expected-return-type.rs:11:25 + | +LL | fn test(shouldwe: Option, shouldwe2: Option) -> u32 { + | --- expected `u32` because of return type +... +LL | None => (), + | ^^ expected `u32`, found `()` + +error: aborting due to previous error + +For more information about this error, try `rustc --explain E0308`. diff --git a/tests/ui/suggestions/remove-question-symbol-with-paren.stderr b/tests/ui/suggestions/remove-question-symbol-with-paren.stderr index 39e35f733a1..40b9cf2dcd4 100644 --- a/tests/ui/suggestions/remove-question-symbol-with-paren.stderr +++ b/tests/ui/suggestions/remove-question-symbol-with-paren.stderr @@ -1,6 +1,9 @@ error[E0308]: `?` operator has incompatible types --> $DIR/remove-question-symbol-with-paren.rs:5:6 | +LL | fn foo() -> Option<()> { + | ---------- expected `Option<()>` because of return type +LL | let x = Some(()); LL | (x?) | ^^ expected `Option<()>`, found `()` |