From 1e1761e9ae847f452b5f279708efa1836380bd46 Mon Sep 17 00:00:00 2001
From: DropDemBits <r3usrlnd@gmail.com>
Date: Tue, 7 Nov 2023 21:22:53 -0500
Subject: [PATCH] Migrate `extract_variable` to mutable ast

---
 .../src/handlers/extract_variable.rs          | 229 ++++++++++--------
 1 file changed, 133 insertions(+), 96 deletions(-)

diff --git a/crates/ide-assists/src/handlers/extract_variable.rs b/crates/ide-assists/src/handlers/extract_variable.rs
index e7c884dcb70..874b81d3b63 100644
--- a/crates/ide-assists/src/handlers/extract_variable.rs
+++ b/crates/ide-assists/src/handlers/extract_variable.rs
@@ -1,12 +1,8 @@
 use hir::TypeInfo;
-use stdx::format_to;
 use syntax::{
-    ast::{self, AstNode},
-    NodeOrToken,
-    SyntaxKind::{
-        BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
-        PATH_EXPR, RETURN_EXPR,
-    },
+    ast::{self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, HasName},
+    ted, NodeOrToken,
+    SyntaxKind::{BLOCK_EXPR, BREAK_EXPR, COMMENT, LOOP_EXPR, MATCH_GUARD, PATH_EXPR, RETURN_EXPR},
     SyntaxNode,
 };
 
@@ -66,98 +62,140 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
         .as_ref()
         .map_or(false, |it| matches!(it, ast::Expr::FieldExpr(_) | ast::Expr::MethodCallExpr(_)));
 
-    let reference_modifier = match ty.filter(|_| needs_adjust) {
-        Some(receiver_type) if receiver_type.is_mutable_reference() => "&mut ",
-        Some(receiver_type) if receiver_type.is_reference() => "&",
-        _ => "",
-    };
-
-    let var_modifier = match parent {
-        Some(ast::Expr::RefExpr(expr)) if expr.mut_token().is_some() => "mut ",
-        _ => "",
-    };
-
     let anchor = Anchor::from(&to_extract)?;
-    let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
     let target = to_extract.syntax().text_range();
     acc.add(
         AssistId("extract_variable", AssistKind::RefactorExtract),
         "Extract into variable",
         target,
         move |edit| {
-            let field_shorthand =
-                match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
-                    Some(field) => field.name_ref(),
-                    None => None,
-                };
+            let field_shorthand = to_extract
+                .syntax()
+                .parent()
+                .and_then(ast::RecordExprField::cast)
+                .filter(|field| field.name_ref().is_some());
 
-            let mut buf = String::new();
+            let (var_name, expr_replace) = match field_shorthand {
+                Some(field) => (field.to_string(), field.syntax().clone()),
+                None => (
+                    suggest_name::for_variable(&to_extract, &ctx.sema),
+                    to_extract.syntax().clone(),
+                ),
+            };
 
-            let var_name = match &field_shorthand {
-                Some(it) => it.to_string(),
-                None => suggest_name::for_variable(&to_extract, &ctx.sema),
+            let ident_pat = match parent {
+                Some(ast::Expr::RefExpr(expr)) if expr.mut_token().is_some() => {
+                    make::ident_pat(false, true, make::name(&var_name))
+                }
+                _ => make::ident_pat(false, false, make::name(&var_name)),
             };
-            let expr_range = match &field_shorthand {
-                Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
-                None => to_extract.syntax().text_range(),
+
+            let to_extract = match ty.as_ref().filter(|_| needs_adjust) {
+                Some(receiver_type) if receiver_type.is_mutable_reference() => {
+                    make::expr_ref(to_extract, true)
+                }
+                Some(receiver_type) if receiver_type.is_reference() => {
+                    make::expr_ref(to_extract, false)
+                }
+                _ => to_extract,
             };
 
+            let expr_replace = edit.make_syntax_mut(expr_replace);
+            let let_stmt =
+                make::let_stmt(ident_pat.into(), None, Some(to_extract)).clone_for_update();
+            let name_expr = make::expr_path(make::ext::ident_path(&var_name)).clone_for_update();
+
             match anchor {
-                Anchor::Before(_) | Anchor::Replace(_) => {
-                    format_to!(buf, "let {var_modifier}{var_name} = {reference_modifier}")
-                }
-                Anchor::WrapInBlock(_) => {
-                    format_to!(buf, "{{ let {var_name} = {reference_modifier}")
-                }
-            };
-            format_to!(buf, "{to_extract}");
+                Anchor::Before(place) => {
+                    let prev_ws = place.prev_sibling_or_token().and_then(|it| it.into_token());
+                    let indent_to = IndentLevel::from_node(&place);
+                    let insert_place = edit.make_syntax_mut(place);
 
-            if let Anchor::Replace(stmt) = anchor {
-                cov_mark::hit!(test_extract_var_expr_stmt);
-                if stmt.semicolon_token().is_none() {
-                    buf.push(';');
-                }
-                match ctx.config.snippet_cap {
-                    Some(cap) => {
-                        let snip = buf.replace(
-                            &format!("let {var_modifier}{var_name}"),
-                            &format!("let {var_modifier}$0{var_name}"),
-                        );
-                        edit.replace_snippet(cap, expr_range, snip)
-                    }
-                    None => edit.replace(expr_range, buf),
-                }
-                return;
-            }
+                    // Adjust ws to insert depending on if this is all inline or on separate lines
+                    let trailing_ws = if prev_ws.is_some_and(|it| it.text().starts_with("\n")) {
+                        format!("\n{indent_to}")
+                    } else {
+                        format!(" ")
+                    };
 
-            buf.push(';');
-
-            // We want to maintain the indent level,
-            // but we do not want to duplicate possible
-            // extra newlines in the indent block
-            let text = indent.text();
-            if text.starts_with('\n') {
-                buf.push('\n');
-                buf.push_str(text.trim_start_matches('\n'));
-            } else {
-                buf.push_str(text);
-            }
-
-            edit.replace(expr_range, var_name.clone());
-            let offset = anchor.syntax().text_range().start();
-            match ctx.config.snippet_cap {
-                Some(cap) => {
-                    let snip = buf.replace(
-                        &format!("let {var_modifier}{var_name}"),
-                        &format!("let {var_modifier}$0{var_name}"),
+                    ted::insert_all_raw(
+                        ted::Position::before(insert_place),
+                        vec![
+                            let_stmt.syntax().clone().into(),
+                            make::tokens::whitespace(&trailing_ws).into(),
+                        ],
                     );
-                    edit.insert_snippet(cap, offset, snip)
-                }
-                None => edit.insert(offset, buf),
-            }
 
-            if let Anchor::WrapInBlock(_) = anchor {
-                edit.insert(anchor.syntax().text_range().end(), " }");
+                    ted::replace(expr_replace, name_expr.syntax());
+
+                    if let Some(cap) = ctx.config.snippet_cap {
+                        if let Some(ast::Pat::IdentPat(ident_pat)) = let_stmt.pat() {
+                            if let Some(name) = ident_pat.name() {
+                                edit.add_tabstop_before(cap, name);
+                            }
+                        }
+                    }
+                }
+                Anchor::Replace(stmt) => {
+                    cov_mark::hit!(test_extract_var_expr_stmt);
+
+                    let stmt_replace = edit.make_mut(stmt);
+                    ted::replace(stmt_replace.syntax(), let_stmt.syntax());
+
+                    if let Some(cap) = ctx.config.snippet_cap {
+                        if let Some(ast::Pat::IdentPat(ident_pat)) = let_stmt.pat() {
+                            if let Some(name) = ident_pat.name() {
+                                edit.add_tabstop_before(cap, name);
+                            }
+                        }
+                    }
+                }
+                Anchor::WrapInBlock(to_wrap) => {
+                    let indent_to = to_wrap.indent_level();
+
+                    let block = if to_wrap.syntax() == &expr_replace {
+                        // Since `expr_replace` is the same that needs to be wrapped in a block,
+                        // we can just directly replace it with a block
+                        let block =
+                            make::block_expr([let_stmt.into()], Some(name_expr)).clone_for_update();
+                        ted::replace(expr_replace, block.syntax());
+
+                        block
+                    } else {
+                        // `expr_replace` is a descendant of `to_wrap`, so both steps need to be
+                        // handled seperately, otherwise we wrap the wrong expression
+                        let to_wrap = edit.make_mut(to_wrap);
+
+                        // Replace the target expr first so that we don't need to find where
+                        // `expr_replace` is in the wrapped `to_wrap`
+                        ted::replace(expr_replace, name_expr.syntax());
+
+                        // Wrap `to_wrap` in a block
+                        let block = make::block_expr([let_stmt.into()], Some(to_wrap.clone()))
+                            .clone_for_update();
+                        ted::replace(to_wrap.syntax(), block.syntax());
+
+                        block
+                    };
+
+                    if let Some(cap) = ctx.config.snippet_cap {
+                        // Adding a tabstop to `name` requires finding the let stmt again, since
+                        // the existing `let_stmt` is not actually added to the tree
+                        let pat = block.statements().find_map(|stmt| {
+                            let ast::Stmt::LetStmt(let_stmt) = stmt else { return None };
+                            let_stmt.pat()
+                        });
+
+                        if let Some(ast::Pat::IdentPat(ident_pat)) = pat {
+                            if let Some(name) = ident_pat.name() {
+                                edit.add_tabstop_before(cap, name);
+                            }
+                        }
+                    }
+
+                    // fixup indentation of block
+                    block.indent(indent_to);
+                }
             }
         },
     )
@@ -181,7 +219,7 @@ fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
 enum Anchor {
     Before(SyntaxNode),
     Replace(ast::ExprStmt),
-    WrapInBlock(SyntaxNode),
+    WrapInBlock(ast::Expr),
 }
 
 impl Anchor {
@@ -204,16 +242,16 @@ impl Anchor {
                 }
 
                 if let Some(parent) = node.parent() {
-                    if parent.kind() == CLOSURE_EXPR {
+                    if let Some(parent) = ast::ClosureExpr::cast(parent.clone()) {
                         cov_mark::hit!(test_extract_var_in_closure_no_block);
-                        return Some(Anchor::WrapInBlock(node));
+                        return parent.body().map(Anchor::WrapInBlock);
                     }
-                    if parent.kind() == MATCH_ARM {
+                    if let Some(parent) = ast::MatchArm::cast(parent) {
                         if node.kind() == MATCH_GUARD {
                             cov_mark::hit!(test_extract_var_in_match_guard);
                         } else {
                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
-                            return Some(Anchor::WrapInBlock(node));
+                            return parent.expr().map(Anchor::WrapInBlock);
                         }
                     }
                 }
@@ -229,13 +267,6 @@ impl Anchor {
                 None
             })
     }
-
-    fn syntax(&self) -> &SyntaxNode {
-        match self {
-            Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
-            Anchor::Replace(stmt) => stmt.syntax(),
-        }
-    }
 }
 
 #[cfg(test)]
@@ -502,7 +533,10 @@ fn main() {
 fn main() {
     let x = true;
     let tuple = match x {
-        true => { let $0var_name = 2 + 2; (var_name, true) }
+        true => {
+            let $0var_name = 2 + 2;
+            (var_name, true)
+        }
         _ => (0, false)
     };
 }
@@ -579,7 +613,10 @@ fn main() {
 "#,
             r#"
 fn main() {
-    let lambda = |x: u32| { let $0var_name = x * 2; var_name };
+    let lambda = |x: u32| {
+        let $0var_name = x * 2;
+        var_name
+    };
 }
 "#,
         );