diff --git a/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs b/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs index 65d0640a383..ae582dc7ecb 100644 --- a/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs +++ b/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs @@ -1,6 +1,6 @@ use std::iter; -use ide_db::helpers::for_each_tail_expr; +use ide_db::helpers::{for_each_tail_expr, FamousDefs}; use syntax::{ ast::{self, make, Expr}, match_ast, AstNode, @@ -33,16 +33,15 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) _ => return None, } }; - let body = ast::Expr::BlockExpr(body); let type_ref = &ret_type.ty()?; - let ret_type_str = type_ref.syntax().text().to_string(); - let first_part_ret_type = ret_type_str.splitn(2, '<').next(); - if let Some(ret_type_first_part) = first_part_ret_type { - if ret_type_first_part.ends_with("Result") { - cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result); - return None; - } + let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt()); + let result_enum = + FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?; + + if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) { + cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result); + return None; } acc.add( @@ -50,6 +49,8 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) "Wrap return type in Result", type_ref.syntax().text_range(), |builder| { + let body = ast::Expr::BlockExpr(body); + let mut exprs_to_wrap = Vec::new(); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); body.walk(&mut |expr| { @@ -88,6 +89,11 @@ fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e)) } } + Expr::ReturnExpr(ret_expr) => { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e)); + } + } e => acc.push(e.clone()), } } @@ -98,10 +104,17 @@ mod tests { use super::*; - #[test] - fn wrap_return_type_in_result_simple() { + fn check(ra_fixture_before: &str, ra_fixture_after: &str) { check_assist( wrap_return_type_in_result, + &format!("//- minicore: result\n{}", ra_fixture_before.trim_start()), + ra_fixture_after, + ); + } + + #[test] + fn wrap_return_type_in_result_simple() { + check( r#" fn foo() -> i3$02 { let test = "test"; @@ -119,8 +132,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_break_split_tail() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i3$02 { loop { @@ -148,8 +160,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_closure() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() { || -> i32$0 { @@ -207,7 +218,8 @@ fn foo() { check_assist_not_applicable( wrap_return_type_in_result, r#" -fn foo() -> std::result::Result { +//- minicore: result +fn foo() -> core::result::Result { let test = "test"; return 42i32; } @@ -221,6 +233,7 @@ fn foo() -> std::result::Result { check_assist_not_applicable( wrap_return_type_in_result, r#" +//- minicore: result fn foo() -> Result { let test = "test"; return 42i32; @@ -246,8 +259,7 @@ fn foo() { #[test] fn wrap_return_type_in_result_simple_with_cursor() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> $0i32 { let test = "test"; @@ -265,8 +277,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_tail() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() ->$0 i32 { let test = "test"; @@ -284,8 +295,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_tail_closure() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() { || ->$0 i32 { @@ -307,17 +317,12 @@ fn foo() { #[test] fn wrap_return_type_in_result_simple_with_tail_only() { - check_assist( - wrap_return_type_in_result, - r#"fn foo() -> i32$0 { 42i32 }"#, - r#"fn foo() -> Result { Ok(42i32) }"#, - ); + check(r#"fn foo() -> i32$0 { 42i32 }"#, r#"fn foo() -> Result { Ok(42i32) }"#); } #[test] fn wrap_return_type_in_result_simple_with_tail_block_like() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { if true { @@ -341,8 +346,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_without_block_closure() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() { || -> i32$0 { @@ -370,8 +374,7 @@ fn foo() { #[test] fn wrap_return_type_in_result_simple_with_nested_if() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { if true { @@ -403,8 +406,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_await() { - check_assist( - wrap_return_type_in_result, + check( r#" async fn foo() -> i$032 { if true { @@ -436,8 +438,7 @@ async fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_array() { - check_assist( - wrap_return_type_in_result, + check( r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#, r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#, ); @@ -445,8 +446,7 @@ async fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_cast() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -$0> i32 { if true { @@ -478,8 +478,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_tail_block_like_match() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let my_var = 5; @@ -503,8 +502,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_loop_with_tail() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let my_var = 5; @@ -530,8 +528,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let my_var = let x = loop { @@ -553,8 +550,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let my_var = 5; @@ -577,8 +573,7 @@ fn foo() -> Result { "#, ); - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let my_var = 5; @@ -606,8 +601,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let my_var = 5; @@ -655,8 +649,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i$032 { let test = "test"; @@ -680,8 +673,7 @@ fn foo() -> Result { #[test] fn wrap_return_type_in_result_simple_with_closure() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo(the_field: u32) ->$0 u32 { let true_closure = || { return true; }; @@ -712,55 +704,53 @@ fn foo(the_field: u32) -> Result { "#, ); - check_assist( - wrap_return_type_in_result, + check( r#" - fn foo(the_field: u32) -> u32$0 { - let true_closure = || { - return true; - }; - if the_field < 5 { - let mut i = 0; +fn foo(the_field: u32) -> u32$0 { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; - if true_closure() { - return 99; - } else { - return 0; - } - } - let t = None; + if true_closure() { + return 99; + } else { + return 0; + } + } + let t = None; - t.unwrap_or_else(|| the_field) - } - "#, + t.unwrap_or_else(|| the_field) +} +"#, r#" - fn foo(the_field: u32) -> Result { - let true_closure = || { - return true; - }; - if the_field < 5 { - let mut i = 0; +fn foo(the_field: u32) -> Result { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; - if true_closure() { - return Ok(99); - } else { - return Ok(0); - } - } - let t = None; + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + let t = None; - Ok(t.unwrap_or_else(|| the_field)) - } - "#, + Ok(t.unwrap_or_else(|| the_field)) +} +"#, ); } #[test] fn wrap_return_type_in_result_simple_with_weird_forms() { - check_assist( - wrap_return_type_in_result, + check( r#" fn foo() -> i32$0 { let test = "test"; @@ -793,8 +783,7 @@ fn foo() -> Result { "#, ); - check_assist( - wrap_return_type_in_result, + check( r#" fn foo(the_field: u32) -> u32$0 { if the_field < 5 { @@ -833,8 +822,7 @@ fn foo(the_field: u32) -> Result { "#, ); - check_assist( - wrap_return_type_in_result, + check( r#" fn foo(the_field: u32) -> u3$02 { if the_field < 5 { @@ -861,8 +849,7 @@ fn foo(the_field: u32) -> Result { "#, ); - check_assist( - wrap_return_type_in_result, + check( r#" fn foo(the_field: u32) -> u32$0 { if the_field < 5 { @@ -891,8 +878,7 @@ fn foo(the_field: u32) -> Result { "#, ); - check_assist( - wrap_return_type_in_result, + check( r#" fn foo(the_field: u32) -> $0u32 { if the_field < 5 { diff --git a/crates/ide_db/src/helpers.rs b/crates/ide_db/src/helpers.rs index 632fd365901..e8cdcbf3fa5 100644 --- a/crates/ide_db/src/helpers.rs +++ b/crates/ide_db/src/helpers.rs @@ -122,6 +122,10 @@ impl FamousDefs<'_, '_> { self.find_enum("core:option:Option") } + pub fn core_result_Result(&self) -> Option { + self.find_enum("core:result:Result") + } + pub fn core_default_Default(&self) -> Option { self.find_trait("core:default:Default") } @@ -206,6 +210,7 @@ impl SnippetCap { } /// Calls `cb` on each expression inside `expr` that is at "tail position". +/// Does not walk into `break` or `return` expressions. pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) { match expr { ast::Expr::BlockExpr(b) => {