Use more strictly typed syntax nodes for analysis in extract_function assist

This commit is contained in:
Lukas Wirth 2021-07-29 17:17:45 +02:00
parent 2b461c50d7
commit b537cb186e
4 changed files with 200 additions and 89 deletions

1
Cargo.lock generated
View File

@ -600,6 +600,7 @@ dependencies = [
"expect-test", "expect-test",
"hir", "hir",
"ide_db", "ide_db",
"indexmap",
"itertools", "itertools",
"profile", "profile",
"rustc-hash", "rustc-hash",

View File

@ -13,6 +13,7 @@ cov-mark = "2.0.0-pre.1"
rustc-hash = "1.1.0" rustc-hash = "1.1.0"
itertools = "0.10.0" itertools = "0.10.0"
either = "1.6.1" either = "1.6.1"
indexmap = "1.6.2"
stdx = { path = "../stdx", version = "0.0.0" } stdx = { path = "../stdx", version = "0.0.0" }
syntax = { path = "../syntax", version = "0.0.0" } syntax = { path = "../syntax", version = "0.0.0" }

View File

@ -1,13 +1,14 @@
use std::iter; use std::{hash::BuildHasherDefault, iter};
use ast::make; use ast::make;
use either::Either; use either::Either;
use hir::{HirDisplay, Local}; use hir::{HirDisplay, Local, Semantics};
use ide_db::{ use ide_db::{
defs::{Definition, NameRefClass}, defs::{Definition, NameRefClass},
search::{FileReference, ReferenceAccess, SearchScope}, search::{FileReference, ReferenceAccess, SearchScope},
RootDatabase,
}; };
use itertools::Itertools; use rustc_hash::FxHasher;
use stdx::format_to; use stdx::format_to;
use syntax::{ use syntax::{
ast::{ ast::{
@ -25,6 +26,8 @@ use crate::{
AssistId, AssistId,
}; };
type FxIndexSet<T> = indexmap::IndexSet<T, BuildHasherDefault<FxHasher>>;
// Assist: extract_function // Assist: extract_function
// //
// Extracts selected statements into new function. // Extracts selected statements into new function.
@ -51,7 +54,8 @@ use crate::{
// } // }
// ``` // ```
pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
if ctx.frange.range.is_empty() { let range = ctx.frange.range;
if range.is_empty() {
return None; return None;
} }
@ -65,11 +69,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
syntax::NodeOrToken::Node(n) => n, syntax::NodeOrToken::Node(n) => n,
syntax::NodeOrToken::Token(t) => t.parent()?, syntax::NodeOrToken::Token(t) => t.parent()?,
}; };
let body = extraction_target(&node, range)?;
let body = extraction_target(&node, ctx.frange.range)?; let (locals_used, has_await, self_param) = analyze_body(&ctx.sema, &body);
let vars_used_in_body = vars_used_in_body(ctx, &body);
let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body);
let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
let insert_after = scope_for_fn_insertion(&body, anchor)?; let insert_after = scope_for_fn_insertion(&body, anchor)?;
@ -95,7 +97,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
"Extract into function", "Extract into function",
target_range, target_range,
move |builder| { move |builder| {
let params = extracted_function_params(ctx, &body, &vars_used_in_body); let params = extracted_function_params(ctx, &body, locals_used.iter().copied());
let fun = Function { let fun = Function {
name: "fun_name".to_string(), name: "fun_name".to_string(),
@ -109,15 +111,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
let new_indent = IndentLevel::from_node(&insert_after); let new_indent = IndentLevel::from_node(&insert_after);
let old_indent = fun.body.indent_level(); let old_indent = fun.body.indent_level();
let body_contains_await = body_contains_await(&fun.body);
builder.replace( builder.replace(target_range, format_replacement(ctx, &fun, old_indent, has_await));
target_range,
format_replacement(ctx, &fun, old_indent, body_contains_await),
);
let fn_def = let fn_def = format_function(ctx, module, &fun, old_indent, new_indent, has_await);
format_function(ctx, module, &fun, old_indent, new_indent, body_contains_await);
let insert_offset = insert_after.text_range().end(); let insert_offset = insert_after.text_range().end();
match ctx.config.snippet_cap { match ctx.config.snippet_cap {
Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def), Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def),
@ -500,15 +497,59 @@ impl FunctionBody {
} }
} }
fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ { fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) {
match self { match self {
FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()), FunctionBody::Expr(expr) => expr.walk(cb),
FunctionBody::Span { parent, text_range } => Either::Left( FunctionBody::Span { parent, text_range } => {
parent parent
.syntax() .statements()
.descendants() .filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
.filter(move |it| text_range.contains_range(it.text_range())), .filter_map(|stmt| match stmt {
), ast::Stmt::ExprStmt(expr_stmt) => expr_stmt.expr(),
ast::Stmt::Item(_) => None,
ast::Stmt::LetStmt(stmt) => stmt.initializer(),
})
.for_each(|expr| expr.walk(cb));
if let Some(expr) = parent
.tail_expr()
.filter(|it| text_range.contains_range(it.syntax().text_range()))
{
expr.walk(cb);
}
}
}
}
fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) {
match self {
FunctionBody::Expr(expr) => expr.walk_patterns(cb),
FunctionBody::Span { parent, text_range } => {
parent
.statements()
.filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
.for_each(|stmt| match stmt {
ast::Stmt::ExprStmt(expr_stmt) => {
if let Some(expr) = expr_stmt.expr() {
expr.walk_patterns(cb)
}
}
ast::Stmt::Item(_) => (),
ast::Stmt::LetStmt(stmt) => {
if let Some(pat) = stmt.pat() {
pat.walk(cb);
}
if let Some(expr) = stmt.initializer() {
expr.walk_patterns(cb);
}
}
});
if let Some(expr) = parent
.tail_expr()
.filter(|it| text_range.contains_range(it.syntax().text_range()))
{
expr.walk_patterns(cb);
}
}
} }
} }
@ -622,58 +663,48 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
node.ancestors().find_map(ast::Expr::cast).and_then(FunctionBody::from_expr) node.ancestors().find_map(ast::Expr::cast).and_then(FunctionBody::from_expr)
} }
/// list local variables that are referenced in `body` /// Analyzes a function body, returning the used local variables that are referenced in it as well as
fn vars_used_in_body(ctx: &AssistContext, body: &FunctionBody) -> Vec<Local> { /// whether it contains an await expression.
// FIXME: currently usages inside macros are not found fn analyze_body(
body.descendants() sema: &Semantics<RootDatabase>,
.filter_map(ast::NameRef::cast)
.filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
.map(|name_kind| match name_kind {
NameRefClass::Definition(def) => def,
NameRefClass::FieldShorthand { local_ref, field_ref: _ } => {
Definition::Local(local_ref)
}
})
.filter_map(|definition| match definition {
Definition::Local(local) => Some(local),
_ => None,
})
.unique()
.collect()
}
fn body_contains_await(body: &FunctionBody) -> bool {
body.descendants().any(|d| matches!(d.kind(), SyntaxKind::AWAIT_EXPR))
}
/// find `self` param, that was not defined inside `body`
///
/// It should skip `self` params from impls inside `body`
fn self_param_from_usages(
ctx: &AssistContext,
body: &FunctionBody, body: &FunctionBody,
vars_used_in_body: &[Local], ) -> (FxIndexSet<Local>, bool, Option<(Local, ast::SelfParam)>) {
) -> Option<(Local, ast::SelfParam)> { // FIXME: currently usages inside macros are not found
let mut iter = vars_used_in_body let mut has_await = false;
.iter() let mut self_param = None;
.filter(|var| var.is_self(ctx.db())) let mut res = FxIndexSet::default();
.map(|var| (var, var.source(ctx.db()))) body.walk_expr(&mut |expr| {
.filter(|(_, src)| is_defined_before(ctx, body, src)) has_await |= matches!(expr, ast::Expr::AwaitExpr(_));
.filter_map(|(&node, src)| match src.value { let name_ref = match expr {
Either::Right(it) => Some((node, it)), ast::Expr::PathExpr(path_expr) => {
Either::Left(_) => { path_expr.path().and_then(|it| it.as_single_name_ref())
stdx::never!(false, "Local::is_self returned true, but source is IdentPat");
None
} }
}); _ => return,
};
let self_param = iter.next(); if let Some(name_ref) = name_ref {
stdx::always!( if let Some(
iter.next().is_none(), NameRefClass::Definition(Definition::Local(local_ref))
"body references two different self params, both defined outside" | NameRefClass::FieldShorthand { local_ref, field_ref: _ },
); ) = NameRefClass::classify(sema, &name_ref)
{
self_param res.insert(local_ref);
if local_ref.is_self(sema.db) {
match local_ref.source(sema.db).value {
Either::Right(it) => {
stdx::always!(
self_param.replace((local_ref, it)).is_none(),
"body references two different self params"
);
}
Either::Left(_) => {
stdx::never!("Local::is_self returned true, but source is IdentPat");
}
}
}
}
}
});
(res, has_await, self_param)
} }
/// find variables that should be extracted as params /// find variables that should be extracted as params
@ -682,16 +713,15 @@ fn self_param_from_usages(
fn extracted_function_params( fn extracted_function_params(
ctx: &AssistContext, ctx: &AssistContext,
body: &FunctionBody, body: &FunctionBody,
vars_used_in_body: &[Local], locals: impl Iterator<Item = Local>,
) -> Vec<Param> { ) -> Vec<Param> {
vars_used_in_body locals
.iter() .filter(|local| !local.is_self(ctx.db()))
.filter(|var| !var.is_self(ctx.db())) .map(|local| (local, local.source(ctx.db())))
.map(|node| (node, node.source(ctx.db()))) .filter(|(_, src)| is_defined_outside_of_body(ctx, body, src))
.filter(|(_, src)| is_defined_before(ctx, body, src)) .filter_map(|(local, src)| {
.filter_map(|(&node, src)| {
if src.value.is_left() { if src.value.is_left() {
Some(node) Some(local)
} else { } else {
stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); stdx::never!(false, "Local::is_self returned false, but source is SelfParam");
None None
@ -838,14 +868,18 @@ fn path_element_of_reference(
} }
/// list local variables defined inside `body` /// list local variables defined inside `body`
fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { fn locals_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> FxIndexSet<Local> {
// FIXME: this doesn't work well with macros // FIXME: this doesn't work well with macros
// see https://github.com/rust-analyzer/rust-analyzer/pull/7535#discussion_r570048550 // see https://github.com/rust-analyzer/rust-analyzer/pull/7535#discussion_r570048550
body.descendants() let mut res = FxIndexSet::default();
.filter_map(ast::IdentPat::cast) body.walk_pat(&mut |pat| {
.filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) if let ast::Pat::IdentPat(pat) = pat {
.unique() if let Some(local) = ctx.sema.to_def(&pat) {
.collect() res.insert(local);
}
}
});
res
} }
/// list local variables defined inside `body` that should be returned from extracted function /// list local variables defined inside `body` that should be returned from extracted function
@ -854,7 +888,7 @@ fn vars_defined_in_body_and_outlive(
body: &FunctionBody, body: &FunctionBody,
parent: &SyntaxNode, parent: &SyntaxNode,
) -> Vec<OutlivedLocal> { ) -> Vec<OutlivedLocal> {
let vars_defined_in_body = vars_defined_in_body(body, ctx); let vars_defined_in_body = locals_defined_in_body(body, ctx);
vars_defined_in_body vars_defined_in_body
.into_iter() .into_iter()
.filter_map(|var| var_outlives_body(ctx, body, var, parent)) .filter_map(|var| var_outlives_body(ctx, body, var, parent))
@ -862,7 +896,7 @@ fn vars_defined_in_body_and_outlive(
} }
/// checks if the relevant local was defined before(outside of) body /// checks if the relevant local was defined before(outside of) body
fn is_defined_before( fn is_defined_outside_of_body(
ctx: &AssistContext, ctx: &AssistContext,
body: &FunctionBody, body: &FunctionBody,
src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>, src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>,

View File

@ -103,6 +103,81 @@ impl ast::Expr {
} }
} }
} }
/// Preorder walk all the expression's child patterns.
///
/// `cb` is invoked (via `ast::Pat::walk`) for the patterns found underneath this
/// expression. Nested items, generic arguments, closures and `async`/`try`/`const`
/// effect blocks are skipped entirely and never reach `cb`.
pub fn walk_patterns(&self, cb: &mut dyn FnMut(ast::Pat)) {
let mut preorder = self.syntax().preorder();
while let Some(event) = preorder.next() {
// Only `Enter` events are acted on; `Leave` events are ignored.
let node = match event {
WalkEvent::Enter(node) => node,
WalkEvent::Leave(_) => continue,
};
match ast::Stmt::cast(node.clone()) {
// `let` statements are handled by hand: walk the binding pattern, recurse into
// the initializer, then skip the subtree so this loop does not visit either of
// them a second time.
Some(ast::Stmt::LetStmt(l)) => {
if let Some(pat) = l.pat() {
pat.walk(cb);
}
if let Some(expr) = l.initializer() {
expr.walk_patterns(cb);
}
preorder.skip_subtree();
}
// Don't skip subtree since we want to process the expression child next
Some(ast::Stmt::ExprStmt(_)) => (),
// skip inner items which might have their own patterns
Some(ast::Stmt::Item(_)) => preorder.skip_subtree(),
None => {
// skip const args, those are a different context
if ast::GenericArg::can_cast(node.kind()) {
preorder.skip_subtree();
} else if let Some(expr) = ast::Expr::cast(node.clone()) {
// Closures and async/try/const effect blocks count as a separate context:
// their subtrees are not descended into.
let is_different_context = match &expr {
ast::Expr::EffectExpr(effect) => {
matches!(
effect.effect(),
ast::Effect::Async(_)
| ast::Effect::Try(_)
| ast::Effect::Const(_)
)
}
ast::Expr::ClosureExpr(_) => true,
_ => false,
};
if is_different_context {
preorder.skip_subtree();
}
} else if let Some(pat) = ast::Pat::cast(node) {
// `pat.walk` traverses the whole pattern on its own, so the subtree is
// skipped here to avoid reporting its sub-patterns twice.
preorder.skip_subtree();
pat.walk(cb);
}
}
}
}
}
}
impl ast::Pat {
/// Visits this pattern and every sub-pattern below it in preorder, handing each
/// one to `cb`.
///
/// Const block patterns are neither reported nor descended into, and generic
/// arguments encountered along the way are skipped as a whole.
pub fn walk(&self, cb: &mut dyn FnMut(ast::Pat)) {
    let mut events = self.syntax().preorder();
    while let Some(event) = events.next() {
        // Only node entry matters; `Leave` events introduce nothing new.
        let entered = if let WalkEvent::Enter(it) = event { it } else { continue };
        let kind = entered.kind();
        if let Some(pat) = ast::Pat::cast(entered) {
            if matches!(pat, ast::Pat::ConstBlockPat(_)) {
                events.skip_subtree();
            } else {
                cb(pat);
            }
        } else if ast::GenericArg::can_cast(kind) {
            // skip const args
            events.skip_subtree();
        }
    }
}
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]