diff --git a/crates/ide_assists/src/handlers/replace_if_let_with_match.rs b/crates/ide_assists/src/handlers/replace_if_let_with_match.rs index f37aa0d53c6..888b4d090b0 100644 --- a/crates/ide_assists/src/handlers/replace_if_let_with_match.rs +++ b/crates/ide_assists/src/handlers/replace_if_let_with_match.rs @@ -1,4 +1,4 @@ -use std::iter; +use std::iter::{self, successors}; use ide_db::{ty_filter::TryEnum, RootDatabase}; use syntax::{ @@ -17,7 +17,7 @@ use crate::{ // Assist: replace_if_let_with_match // -// Replaces `if let` with an else branch with a `match` expression. +// Replaces a `if let` expression with a `match` expression. // // ``` // enum Action { Move { distance: u32 }, Stop } @@ -43,14 +43,28 @@ use crate::{ // ``` pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; - let cond = if_expr.condition()?; - let pat = cond.pat()?; - let expr = cond.expr()?; - let then_block = if_expr.then_branch()?; - let else_block = match if_expr.else_branch()? { - ast::ElseBranch::Block(it) => it, - ast::ElseBranch::IfExpr(_) => return None, - }; + let mut else_block = None; + let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? { + ast::ElseBranch::IfExpr(expr) => Some(expr), + ast::ElseBranch::Block(block) => { + else_block = Some(block); + None + } + }); + let scrutinee_to_be_expr = if_expr.condition()?.expr()?; + + let mut pat_bodies = Vec::new(); + for if_expr in if_exprs { + let cond = if_expr.condition()?; + let expr = cond.expr()?; + if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() { + // Only if all condition expressions are equal we can merge them into a match + return None; + } + let pat = cond.pat()?; + let body = if_expr.then_branch()?; + pat_bodies.push((pat, body)); + } let target = if_expr.syntax().text_range(); acc.add( @@ -59,33 +73,50 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) target, move |edit| { let match_expr = { - let then_arm = { - let then_block = then_block.reset_indent().indent(IndentLevel(1)); - let then_expr = unwrap_trivial_block(then_block); - make::match_arm(vec![pat.clone()], then_expr) - }; let else_arm = { - let pattern = ctx - .sema - .type_of_pat(&pat) - .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty)) - .map(|it| { - if does_pat_match_variant(&pat, &it.sad_pattern()) { - it.happy_pattern() - } else { - it.sad_pattern() + match else_block { + Some(else_block) => { + let pattern = match &*pat_bodies { + [(pat, _)] => ctx + .sema + .type_of_pat(&pat) + .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty)) + .map(|it| { + if does_pat_match_variant(&pat, &it.sad_pattern()) { + it.happy_pattern() + } else { + it.sad_pattern() + } + }), + _ => None, } - }) - .unwrap_or_else(|| make::wildcard_pat().into()); - let else_expr = unwrap_trivial_block(else_block); - make::match_arm(vec![pattern], else_expr) + .unwrap_or_else(|| make::wildcard_pat().into()); + make::match_arm(iter::once(pattern), unwrap_trivial_block(else_block)) + } + None => make::match_arm( + iter::once(make::wildcard_pat().into()), + make::expr_unit().into(), + ), + } }; - let match_expr = - make::expr_match(expr, make::match_arm_list(vec![then_arm, else_arm])); + let arms = pat_bodies + .into_iter() + .map(|(pat, body)| { + let body = body.reset_indent().indent(IndentLevel(1)); + make::match_arm(vec![pat], unwrap_trivial_block(body)) + }) + .chain(iter::once(else_arm)); + let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms)); match_expr.indent(IndentLevel::from_node(if_expr.syntax())) }; - edit.replace_ast::(if_expr.into(), match_expr); + let expr = + if if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind())) { + make::block_expr(None, Some(match_expr)).into() + } else { + match_expr + }; + edit.replace_ast::(if_expr.into(), expr); }, ) } @@ -182,7 +213,33 @@ mod tests { use crate::tests::{check_assist, check_assist_target}; #[test] - fn test_replace_if_let_with_match_unwraps_simple_expressions() { + fn test_if_let_with_match_no_else() { + check_assist( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn foo(&self) { + if $0let VariantData::Struct(..) = *self { + self.foo(); + } + } +} "#, + r#" +impl VariantData { + pub fn foo(&self) { + match *self { + VariantData::Struct(..) => { + self.foo(); + } + _ => (), + } + } +} "#, + ) + } + + #[test] + fn test_if_let_with_match_basic() { check_assist( replace_if_let_with_match, r#" @@ -190,8 +247,12 @@ impl VariantData { pub fn is_struct(&self) -> bool { if $0let VariantData::Struct(..) = *self { true - } else { + } else if let VariantData::Tuple(..) = *self { false + } else { + bar( + 123 + ) } } } "#, @@ -200,7 +261,12 @@ impl VariantData { pub fn is_struct(&self) -> bool { match *self { VariantData::Struct(..) => true, - _ => false, + VariantData::Tuple(..) => false, + _ => { + bar( + 123 + ) + } } } } "#, @@ -208,53 +274,35 @@ impl VariantData { } #[test] - fn test_replace_if_let_with_match_doesnt_unwrap_multiline_expressions() { + fn test_if_let_with_match_on_tail_if_let() { check_assist( replace_if_let_with_match, r#" -fn foo() { - if $0let VariantData::Struct(..) = a { - bar( - 123 - ) - } else { - false - } -} "#, - r#" -fn foo() { - match a { - VariantData::Struct(..) => { - bar( - 123 - ) - } - _ => false, - } -} "#, - ) - } - - #[test] - fn replace_if_let_with_match_target() { - check_assist_target( - replace_if_let_with_match, - r#" impl VariantData { pub fn is_struct(&self) -> bool { - if $0let VariantData::Struct(..) = *self { + if let VariantData::Struct(..) = *self { true + } else if let$0 VariantData::Tuple(..) = *self { + false } else { false } } } "#, - "if let VariantData::Struct(..) = *self { + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if let VariantData::Struct(..) = *self { true } else { - false - }", - ); + match *self { + VariantData::Tuple(..) => false, + _ => false, + } +} + } +} "#, + ) } #[test] diff --git a/crates/ide_assists/src/utils.rs b/crates/ide_assists/src/utils.rs index 0ec236aa0c7..c3839a2b0be 100644 --- a/crates/ide_assists/src/utils.rs +++ b/crates/ide_assists/src/utils.rs @@ -48,15 +48,14 @@ pub fn extract_trivial_expression(block: &ast::BlockExpr) -> Option { return Some(expr); } // Unwrap `{ continue; }` - let (stmt,) = block.statements().next_tuple()?; + let stmt = block.statements().next()?; if let ast::Stmt::ExprStmt(expr_stmt) = stmt { if has_anything_else(expr_stmt.syntax()) { return None; } let expr = expr_stmt.expr()?; - match expr.syntax().kind() { - CONTINUE_EXPR | BREAK_EXPR | RETURN_EXPR => return Some(expr), - _ => (), + if matches!(expr.syntax().kind(), CONTINUE_EXPR | BREAK_EXPR | RETURN_EXPR) { + return Some(expr); } } None