diff --git a/crates/ide_assists/src/handlers/replace_if_let_with_match.rs b/crates/ide_assists/src/handlers/replace_if_let_with_match.rs index a84f8f12077..b1c5f44213c 100644 --- a/crates/ide_assists/src/handlers/replace_if_let_with_match.rs +++ b/crates/ide_assists/src/handlers/replace_if_let_with_match.rs @@ -176,21 +176,21 @@ fn make_else_arm( // ``` pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?; + let mut arms = match_expr.match_arm_list()?.arms(); - let first_arm = arms.next()?; - let second_arm = arms.next()?; + let (first_arm, second_arm) = (arms.next()?, arms.next()?); if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() { return None; } - let condition_expr = match_expr.expr()?; - let (if_let_pat, then_expr, else_expr) = if is_pat_wildcard_or_sad(&ctx.sema, &first_arm.pat()?) - { - (second_arm.pat()?, second_arm.expr()?, first_arm.expr()?) - } else if is_pat_wildcard_or_sad(&ctx.sema, &second_arm.pat()?) { - (first_arm.pat()?, first_arm.expr()?, second_arm.expr()?) - } else { - return None; - }; + + let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order( + &ctx.sema, + first_arm.pat()?, + second_arm.pat()?, + first_arm.expr()?, + second_arm.expr()?, + )?; + let scrutinee = match_expr.expr()?; let target = match_expr.syntax().text_range(); acc.add( @@ -198,26 +198,25 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) "Replace with if let", target, move |edit| { - let condition = make::condition(condition_expr, Some(if_let_pat)); + let condition = make::condition(scrutinee, Some(if_let_pat)); let then_block = match then_expr.reset_indent() { ast::Expr::BlockExpr(block) => block, expr => make::block_expr(iter::empty(), Some(expr)), }; let else_expr = match else_expr { - ast::Expr::BlockExpr(block) - if block.statements().count() == 0 && block.tail_expr().is_none() => - { - None - } - ast::Expr::TupleExpr(tuple) if tuple.fields().count() == 0 => None, + ast::Expr::BlockExpr(block) if block.is_empty() => None, + ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None, expr => Some(expr), }; let if_let_expr = make::expr_if( condition, then_block, - else_expr.map(|else_expr| { - ast::ElseBranch::Block(make::block_expr(iter::empty(), Some(else_expr))) - }), + else_expr + .map(|expr| match expr { + ast::Expr::BlockExpr(block) => block, + expr => (make::block_expr(iter::empty(), Some(expr))), + }) + .map(ast::ElseBranch::Block), ) .indent(IndentLevel::from_node(match_expr.syntax())); @@ -226,11 +225,50 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) ) } -fn is_pat_wildcard_or_sad(sema: &hir::Semantics, pat: &ast::Pat) -> bool { +/// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order. +fn pick_pattern_and_expr_order( + sema: &hir::Semantics, + pat: ast::Pat, + pat2: ast::Pat, + expr: ast::Expr, + expr2: ast::Expr, +) -> Option<(ast::Pat, ast::Expr, ast::Expr)> { + let res = match (pat, pat2) { + (ast::Pat::WildcardPat(_), _) => return None, + (pat, sad_pat) if is_sad_pat(sema, &sad_pat) => (pat, expr, expr2), + (sad_pat, pat) if is_sad_pat(sema, &sad_pat) => (pat, expr2, expr), + (pat, pat2) => match (binds_name(&pat), binds_name(&pat2)) { + (true, true) => return None, + (true, false) => (pat, expr, expr2), + (false, true) => (pat2, expr2, expr), + (false, false) => (pat, expr, expr2), + }, + }; + Some(res) +} + +fn binds_name(pat: &ast::Pat) -> bool { + let binds_name_v = |pat| binds_name(&pat); + match pat { + ast::Pat::IdentPat(_) => true, + ast::Pat::MacroPat(_) => true, + ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v), + ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v), + ast::Pat::TuplePat(it) => it.fields().any(binds_name_v), + ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v), + ast::Pat::RecordPat(it) => it + .record_pat_field_list() + .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)), + ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v), + ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v), + ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v), + _ => false, + } +} +fn is_sad_pat(sema: &hir::Semantics, pat: &ast::Pat) -> bool { sema.type_of_pat(pat) .and_then(|ty| TryEnum::from_ty(sema, &ty)) - .map(|it| it.sad_pattern().syntax().text() == pat.syntax().text()) - .unwrap_or_else(|| matches!(pat, ast::Pat::WildcardPat(_))) + .map_or(false, |it| it.sad_pattern().syntax().text() == pat.syntax().text()) } #[cfg(test)] @@ -662,4 +700,79 @@ fn main() { "#, ) } + + #[test] + fn replace_match_with_if_let_exhaustive() { + check_assist( + replace_match_with_if_let, + r#" +fn print_source(def_source: ModuleSource) { + match def_so$0urce { + ModuleSource::SourceFile(..) => { println!("source file"); } + ModuleSource::Module(..) => { println!("module"); } + } +} +"#, + r#" +fn print_source(def_source: ModuleSource) { + if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); } +} +"#, + ) + } + + #[test] + fn replace_match_with_if_let_prefer_name_bind() { + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + match $0Foo(0) { + Foo(_) => (), + Bar(bar) => println!("bar {}", bar), + } +} +"#, + r#" +fn foo() { + if let Bar(bar) = Foo(0) { + println!("bar {}", bar) + } +} +"#, + ); + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + match $0Foo(0) { + Bar(bar) => println!("bar {}", bar), + Foo(_) => (), + } +} +"#, + r#" +fn foo() { + if let Bar(bar) = Foo(0) { + println!("bar {}", bar) + } +} +"#, + ); + } + + #[test] + fn replace_match_with_if_let_rejects_double_name_bindings() { + check_assist_not_applicable( + replace_match_with_if_let, + r#" +fn foo() { + match $0Foo(0) { + Foo(foo) => println!("bar {}", foo), + Bar(bar) => println!("bar {}", bar), + } +} +"#, + ); + } } diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index 49c478b817a..d8c5c4e76fb 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -49,6 +49,10 @@ impl ast::BlockExpr { pub fn items(&self) -> AstChildren { support::children(self.syntax()) } + + pub fn is_empty(&self) -> bool { + self.statements().next().is_none() && self.tail_expr().is_none() + } } impl ast::Expr {