diff --git a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs new file mode 100644 index 00000000000..5bf04a3ad37 --- /dev/null +++ b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs @@ -0,0 +1,413 @@ +use ide_db::defs::{Definition, NameRefClass}; +use syntax::{ + ast::{self, HasName}, + ted, AstNode, SyntaxNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: convert_match_to_let_else +// +// Converts let statement with match initializer to let-else statement. +// +// ``` +// # //- minicore: option +// fn foo(opt: Option<()>) { +// let val = $0match opt { +// Some(it) => it, +// None => return, +// }; +// } +// ``` +// -> +// ``` +// fn foo(opt: Option<()>) { +// let Some(val) = opt else { return }; +// } +// ``` +pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?; + let binding = find_binding(let_stmt.pat()?)?; + + let initializer = match let_stmt.initializer() { + Some(ast::Expr::MatchExpr(it)) => it, + _ => return None, + }; + let initializer_expr = initializer.expr()?; + + let (extracting_arm, diverging_arm) = match find_arms(ctx, &initializer) { + Some(it) => it, + None => return None, + }; + if extracting_arm.guard().is_some() { + cov_mark::hit!(extracting_arm_has_guard); + return None; + } + + let diverging_arm_expr = diverging_arm.expr()?; + let extracting_arm_pat = extracting_arm.pat()?; + let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?; + + acc.add( + AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite), + "Convert match to let-else", + let_stmt.syntax().text_range(), + |builder| { + let extracting_arm_pat = rename_variable(&extracting_arm_pat, extracted_variable, binding); + builder.replace( + let_stmt.syntax().text_range(), + format!("let {extracting_arm_pat} = {initializer_expr} else {{ {diverging_arm_expr} }};") + ) + }, + ) +} + +// Given a pattern, find the name introduced to the surrounding scope. +fn find_binding(pat: ast::Pat) -> Option { + if let ast::Pat::IdentPat(ident) = pat { + Some(ident) + } else { + None + } +} + +// Given a match expression, find extracting and diverging arms. +fn find_arms( + ctx: &AssistContext<'_>, + match_expr: &ast::MatchExpr, +) -> Option<(ast::MatchArm, ast::MatchArm)> { + let arms = match_expr.match_arm_list()?.arms().collect::>(); + if arms.len() != 2 { + return None; + } + + let mut extracting = None; + let mut diverging = None; + for arm in arms { + if ctx.sema.type_of_expr(&arm.expr().unwrap()).unwrap().original().is_never() { + diverging = Some(arm); + } else { + extracting = Some(arm); + } + } + + match (extracting, diverging) { + (Some(extracting), Some(diverging)) => Some((extracting, diverging)), + _ => { + cov_mark::hit!(non_diverging_match); + None + } + } +} + +// Given an extracting arm, find the extracted variable. +fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option { + match arm.expr()? { + ast::Expr::PathExpr(path) => { + let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?; + match NameRefClass::classify(&ctx.sema, &name_ref)? { + NameRefClass::Definition(Definition::Local(local)) => { + let source = local.source(ctx.db()).value.left()?; + Some(source.name()?) + } + _ => None, + } + } + _ => { + cov_mark::hit!(extracting_arm_is_not_an_identity_expr); + return None; + } + } +} + +// Rename `extracted` with `binding` in `pat`. +fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::IdentPat) -> SyntaxNode { + let syntax = pat.syntax().clone_for_update(); + let extracted_syntax = syntax.covering_element(extracted.syntax().text_range()); + + // If `extracted` variable is a record field, we should rename it to `binding`, + // otherwise we just need to replace `extracted` with `binding`. + + if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast) + { + if let Some(name_ref) = record_pat_field.field_name() { + ted::replace( + record_pat_field.syntax(), + ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding.into()) + .syntax() + .clone_for_update(), + ); + } + } else { + ted::replace(extracted_syntax, binding.syntax().clone_for_update()); + } + + syntax +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn should_not_be_applicable_for_non_diverging_match() { + cov_mark::check!(non_diverging_match); + check_assist_not_applicable( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + let val = $0match opt { + Some(it) => it, + None => (), + }; +} +"#, + ); + } + + #[test] + fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() { + cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2); + check_assist_not_applicable( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option) { + let val = $0match opt { + Some(it) => it + 1, + None => return, + }; +} +"#, + ); + + check_assist_not_applicable( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + let val = $0match opt { + Some(it) => { + let _ = 1 + 1; + it + }, + None => return, + }; +} +"#, + ); + } + + #[test] + fn should_not_be_applicable_if_extracting_arm_has_guard() { + cov_mark::check!(extracting_arm_has_guard); + check_assist_not_applicable( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + let val = $0match opt { + Some(it) if 2 > 1 => it, + None => return, + }; +} +"#, + ); + } + + #[test] + fn basic_pattern() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + let val = $0match opt { + Some(it) => it, + None => return, + }; +} + "#, + r#" +fn foo(opt: Option<()>) { + let Some(val) = opt else { return }; +} + "#, + ); + } + + #[test] + fn keeps_modifiers() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + let ref mut val = $0match opt { + Some(it) => it, + None => return, + }; +} + "#, + r#" +fn foo(opt: Option<()>) { + let Some(ref mut val) = opt else { return }; +} + "#, + ); + } + + #[test] + fn nested_pattern() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option, result +fn foo(opt: Option>) { + let val = $0match opt { + Some(Ok(it)) => it, + _ => return, + }; +} + "#, + r#" +fn foo(opt: Option>) { + let Some(Ok(val)) = opt else { return }; +} + "#, + ); + } + + #[test] + fn works_with_any_diverging_block() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + loop { + let val = $0match opt { + Some(it) => it, + None => break, + }; + } +} + "#, + r#" +fn foo(opt: Option<()>) { + loop { + let Some(val) = opt else { break }; + } +} + "#, + ); + + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option<()>) { + loop { + let val = $0match opt { + Some(it) => it, + None => continue, + }; + } +} + "#, + r#" +fn foo(opt: Option<()>) { + loop { + let Some(val) = opt else { continue }; + } +} + "#, + ); + + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +fn panic() -> ! {} + +fn foo(opt: Option<()>) { + loop { + let val = $0match opt { + Some(it) => it, + None => panic(), + }; + } +} + "#, + r#" +fn panic() -> ! {} + +fn foo(opt: Option<()>) { + loop { + let Some(val) = opt else { panic() }; + } +} + "#, + ); + } + + #[test] + fn struct_pattern() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +struct Point { + x: i32, + y: i32, +} + +fn foo(opt: Option) { + let val = $0match opt { + Some(Point { x: 0, y }) => y, + _ => return, + }; +} + "#, + r#" +struct Point { + x: i32, + y: i32, +} + +fn foo(opt: Option) { + let Some(Point { x: 0, y: val }) = opt else { return }; +} + "#, + ); + } + + #[test] + fn renames_whole_binding() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +fn foo(opt: Option) -> Option { + let val = $0match opt { + it @ Some(42) => it, + _ => return None, + }; + val +} + "#, + r#" +fn foo(opt: Option) -> Option { + let val @ Some(42) = opt else { return None }; + val +} + "#, + ); + } +} diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs index a07318cefad..387cc631428 100644 --- a/crates/ide-assists/src/lib.rs +++ b/crates/ide-assists/src/lib.rs @@ -120,6 +120,7 @@ mod handlers { mod convert_into_to_from; mod convert_iter_for_each_to_for; mod convert_let_else_to_match; + mod convert_match_to_let_else; mod convert_tuple_struct_to_named_struct; mod convert_named_struct_to_tuple_struct; mod convert_to_guarded_return; @@ -220,6 +221,7 @@ mod handlers { convert_iter_for_each_to_for::convert_for_loop_with_for_each, convert_let_else_to_match::convert_let_else_to_match, convert_named_struct_to_tuple_struct::convert_named_struct_to_tuple_struct, + convert_match_to_let_else::convert_match_to_let_else, convert_to_guarded_return::convert_to_guarded_return, convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct, convert_two_arm_bool_match_to_matches_macro::convert_two_arm_bool_match_to_matches_macro, diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 2c4000efe0f..029d169899b 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -407,6 +407,27 @@ fn main() { ) } +#[test] +fn doctest_convert_match_to_let_else() { + check_doc_test( + "convert_match_to_let_else", + r#####" +//- minicore: option +fn foo(opt: Option<()>) { + let val = $0match opt { + Some(it) => it, + None => return, + }; +} +"#####, + r#####" +fn foo(opt: Option<()>) { + let Some(val) = opt else { return }; +} +"#####, + ) +} + #[test] fn doctest_convert_named_struct_to_tuple_struct() { check_doc_test(