Complete local fn and closure params from surrounding locals scope

Lukas Wirth 2022-01-31 11:56:42 +01:00
parent ddf7b70a0f
commit 6194092086
5 changed files with 203 additions and 92 deletions
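As an illustration of the new behavior (an editorial sketch distilled from the tests added at the end of this diff, not code contained in the commit; `$0` marks the cursor as in those test fixtures): completing inside the parameter list of a nested fn or a closure now offers the locals visible in the surrounding block as `name: Type` binding completions.

fn outer() {
    let foo = 3;
    let bar = 3;
    // completing at $0 in the nested fn's parameter list now offers
    // `bn foo: i32` and `bn bar: i32` in addition to the usual `kw mut`
    fn inner($0) {}
    // the same surrounding locals are offered when completing closure parameters
    |$0| {};
}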

View File

@@ -389,8 +389,8 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
self.imp.scope(node)
}
pub fn scope_at_offset(&self, token: &SyntaxToken, offset: TextSize) -> SemanticsScope<'db> {
self.imp.scope_at_offset(&token.parent().unwrap(), offset)
pub fn scope_at_offset(&self, node: &SyntaxNode, offset: TextSize) -> SemanticsScope<'db> {
self.imp.scope_at_offset(&node, offset)
}
pub fn scope_for_def(&self, def: Trait) -> SemanticsScope<'db> {

View File

@@ -1,9 +1,11 @@
//! See [`complete_fn_param`].
use hir::HirDisplay;
use rustc_hash::FxHashMap;
use syntax::{
algo,
ast::{self, HasModuleItem},
match_ast, AstNode, SyntaxKind,
match_ast, AstNode, Direction, SyntaxKind,
};
use crate::{
@@ -15,14 +17,48 @@ use crate::{
/// functions in a file have a `spam: &mut Spam` parameter, a completion with
/// `spam: &mut Spam` insert text/label and `spam` lookup string will be
/// suggested.
///
/// Also completes parameters for closures and local functions from the locals defined in the surrounding scope.
pub(crate) fn complete_fn_param(acc: &mut Completions, ctx: &CompletionContext) -> Option<()> {
let param_of_fn =
matches!(ctx.pattern_ctx, Some(PatternContext { is_param: Some(ParamKind::Function), .. }));
let (param_list, _, param_kind) = match &ctx.pattern_ctx {
Some(PatternContext { param_ctx: Some(kind), .. }) => kind,
_ => return None,
};
if !param_of_fn {
return None;
let comma_wrapper = comma_wrapper(ctx);
let mut add_new_item_to_acc = |label: &str, lookup: String| {
let mk_item = |label: &str| {
CompletionItem::new(CompletionItemKind::Binding, ctx.source_range(), label)
};
let mut item = match &comma_wrapper {
Some(fmt) => mk_item(&fmt(&label)),
None => mk_item(label),
};
item.lookup_by(lookup);
item.add_to(acc)
};
match param_kind {
ParamKind::Function(function) => {
fill_fn_params(ctx, function, &param_list, add_new_item_to_acc);
}
ParamKind::Closure(closure) => {
let stmt_list = closure.syntax().ancestors().find_map(ast::StmtList::cast)?;
params_from_stmt_list_scope(ctx, stmt_list, |name, ty| {
add_new_item_to_acc(&format!("{name}: {ty}"), name.to_string());
});
}
}
Some(())
}
fn fill_fn_params(
ctx: &CompletionContext,
function: &ast::Fn,
param_list: &ast::ParamList,
mut add_new_item_to_acc: impl FnMut(&str, String),
) {
let mut file_params = FxHashMap::default();
let mut extract_params = |f: ast::Fn| {
@@ -56,23 +92,46 @@ pub(crate) fn complete_fn_param(acc: &mut Completions, ctx: &CompletionContext)
};
}
let function = ctx.token.ancestors().find_map(ast::Fn::cast)?;
let param_list = function.param_list()?;
remove_duplicated(&mut file_params, param_list.params());
let self_completion_items = ["self", "&self", "mut self", "&mut self"];
if should_add_self_completions(ctx, param_list) {
self_completion_items.into_iter().for_each(|self_item| {
add_new_item_to_acc(ctx, acc, self_item.to_string(), self_item.to_string())
if let Some(stmt_list) = function.syntax().parent().and_then(ast::StmtList::cast) {
params_from_stmt_list_scope(ctx, stmt_list, |name, ty| {
file_params.entry(format!("{name}: {ty}")).or_insert(name.to_string());
});
}
file_params.into_iter().try_for_each(|(whole_param, binding)| {
Some(add_new_item_to_acc(ctx, acc, surround_with_commas(ctx, whole_param), binding))
})?;
remove_duplicated(&mut file_params, param_list.params());
let self_completion_items = ["self", "&self", "mut self", "&mut self"];
if should_add_self_completions(ctx, param_list) {
self_completion_items
.into_iter()
.for_each(|self_item| add_new_item_to_acc(self_item, self_item.to_string()));
}
Some(())
file_params
.into_iter()
.for_each(|(whole_param, binding)| add_new_item_to_acc(&whole_param, binding));
}
fn params_from_stmt_list_scope(
ctx: &CompletionContext,
stmt_list: ast::StmtList,
mut cb: impl FnMut(hir::Name, String),
) {
let syntax_node = match stmt_list.syntax().last_child() {
Some(it) => it,
None => return,
};
let scope = ctx.sema.scope_at_offset(stmt_list.syntax(), syntax_node.text_range().end());
let module = match scope.module() {
Some(it) => it,
None => return,
};
scope.process_all_names(&mut |name, def| {
if let hir::ScopeDef::Local(local) = def {
if let Ok(ty) = local.ty(ctx.db).display_source_code(ctx.db, module.into()) {
cb(name, ty);
}
}
});
}
fn remove_duplicated(
@@ -96,52 +155,32 @@ fn remove_duplicated(
})
}
fn should_add_self_completions(ctx: &CompletionContext, param_list: ast::ParamList) -> bool {
fn should_add_self_completions(ctx: &CompletionContext, param_list: &ast::ParamList) -> bool {
let inside_impl = ctx.impl_def.is_some();
let no_params = param_list.params().next().is_none() && param_list.self_param().is_none();
inside_impl && no_params
}
fn surround_with_commas(ctx: &CompletionContext, param: String) -> String {
match fallible_surround_with_commas(ctx, &param) {
Some(surrounded) => surrounded,
// fallback to the original parameter
None => param,
}
}
fn fallible_surround_with_commas(ctx: &CompletionContext, param: &str) -> Option<String> {
let next_token = {
fn comma_wrapper(ctx: &CompletionContext) -> Option<impl Fn(&str) -> String> {
let next_token_kind = {
let t = ctx.token.next_token()?;
match t.kind() {
SyntaxKind::WHITESPACE => t.next_token()?,
_ => t,
}
let t = algo::skip_whitespace_token(t, Direction::Next)?;
t.kind()
};
let prev_token_kind = {
let t = ctx.previous_token.clone()?;
let t = algo::skip_whitespace_token(t, Direction::Prev)?;
t.kind()
};
let trailing_comma_missing = matches!(next_token.kind(), SyntaxKind::IDENT);
let trailing = if trailing_comma_missing { "," } else { "" };
let has_trailing_comma =
matches!(next_token_kind, SyntaxKind::COMMA | SyntaxKind::R_PAREN | SyntaxKind::PIPE);
let trailing = if has_trailing_comma { "" } else { "," };
let previous_token = if matches!(ctx.token.kind(), SyntaxKind::IDENT | SyntaxKind::WHITESPACE) {
ctx.previous_token.as_ref()?
} else {
&ctx.token
};
let has_leading_comma =
matches!(prev_token_kind, SyntaxKind::COMMA | SyntaxKind::L_PAREN | SyntaxKind::PIPE);
let leading = if has_leading_comma { "" } else { ", " };
let needs_leading = !matches!(previous_token.kind(), SyntaxKind::L_PAREN | SyntaxKind::COMMA);
let leading = if needs_leading { ", " } else { "" };
Some(format!("{}{}{}", leading, param, trailing))
}
fn add_new_item_to_acc(
ctx: &CompletionContext,
acc: &mut Completions,
label: String,
lookup: String,
) {
let mut item = CompletionItem::new(CompletionItemKind::Binding, ctx.source_range(), label);
item.lookup_by(lookup);
item.add_to(acc)
Some(move |param: &_| format!("{}{}{}", leading, param, trailing))
}

View File

@@ -27,6 +27,8 @@ use crate::{
CompletionConfig,
};
const COMPLETION_MARKER: &str = "intellijRulezz";
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum PatternRefutability {
Refutable,
@@ -68,7 +70,7 @@ pub(crate) struct PathCompletionContext {
#[derive(Debug)]
pub(super) struct PatternContext {
pub(super) refutability: PatternRefutability,
pub(super) is_param: Option<ParamKind>,
pub(super) param_ctx: Option<(ast::ParamList, ast::Param, ParamKind)>,
pub(super) has_type_ascription: bool,
}
@@ -80,10 +82,10 @@ pub(super) enum LifetimeContext {
LabelDef,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ParamKind {
Function,
Closure,
Function(ast::Fn),
Closure(ast::ClosureExpr),
}
/// `CompletionContext` is created early during completion to figure out, where
@@ -382,7 +384,7 @@ impl<'a> CompletionContext<'a> {
// actual completion.
let file_with_fake_ident = {
let parse = db.parse(file_id);
let edit = Indel::insert(offset, "intellijRulezz".to_string());
let edit = Indel::insert(offset, COMPLETION_MARKER.to_string());
parse.reparse(&edit).tree()
};
let fake_ident_token =
@@ -390,7 +392,7 @@ impl<'a> CompletionContext<'a> {
let original_token = original_file.syntax().token_at_offset(offset).left_biased()?;
let token = sema.descend_into_macros_single(original_token.clone());
let scope = sema.scope_at_offset(&token, offset);
let scope = sema.scope_at_offset(&token.parent()?, offset);
let krate = scope.krate();
let mut locals = vec![];
scope.process_all_names(&mut |name, scope| {
@@ -723,7 +725,7 @@ impl<'a> CompletionContext<'a> {
}
}
ast::NameLike::Name(name) => {
self.pattern_ctx = Self::classify_name(&self.sema, name);
self.pattern_ctx = Self::classify_name(&self.sema, original_file, name);
}
}
}
@@ -750,7 +752,11 @@ impl<'a> CompletionContext<'a> {
})
}
fn classify_name(_sema: &Semantics<RootDatabase>, name: ast::Name) -> Option<PatternContext> {
fn classify_name(
_sema: &Semantics<RootDatabase>,
original_file: &SyntaxNode,
name: ast::Name,
) -> Option<PatternContext> {
let bind_pat = name.syntax().parent().and_then(ast::IdentPat::cast)?;
let is_name_in_field_pat = bind_pat
.syntax()
@@ -763,7 +769,7 @@ impl<'a> CompletionContext<'a> {
if !bind_pat.is_simple_ident() {
return None;
}
Some(pattern_context_for(bind_pat.into()))
Some(pattern_context_for(original_file, bind_pat.into()))
}
fn classify_name_ref(
@@ -799,15 +805,15 @@ impl<'a> CompletionContext<'a> {
},
ast::TupleStructPat(it) => {
path_ctx.has_call_parens = true;
pat_ctx = Some(pattern_context_for(it.into()));
pat_ctx = Some(pattern_context_for(original_file, it.into()));
Some(PathKind::Pat)
},
ast::RecordPat(it) => {
pat_ctx = Some(pattern_context_for(it.into()));
pat_ctx = Some(pattern_context_for(original_file, it.into()));
Some(PathKind::Pat)
},
ast::PathPat(it) => {
pat_ctx = Some(pattern_context_for(it.into()));
pat_ctx = Some(pattern_context_for(original_file, it.into()));
Some(PathKind::Pat)
},
ast::MacroCall(it) => it.excl_token().and(Some(PathKind::Mac)),
@@ -824,12 +830,7 @@ impl<'a> CompletionContext<'a> {
path_ctx.use_tree_parent = use_tree_parent;
path_ctx.qualifier = path
.segment()
.and_then(|it| {
find_node_with_range::<ast::PathSegment>(
original_file,
it.syntax().text_range(),
)
})
.and_then(|it| find_node_in_file(original_file, &it))
.map(|it| it.parent_path());
return Some((path_ctx, pat_ctx));
}
@@ -864,7 +865,7 @@ impl<'a> CompletionContext<'a> {
}
}
fn pattern_context_for(pat: ast::Pat) -> PatternContext {
fn pattern_context_for(original_file: &SyntaxNode, pat: ast::Pat) -> PatternContext {
let mut is_param = None;
let (refutability, has_type_ascription) =
pat
@@ -877,18 +878,21 @@ fn pattern_context_for(pat: ast::Pat) -> PatternContext {
match node {
ast::LetStmt(let_) => return (PatternRefutability::Irrefutable, let_.ty().is_some()),
ast::Param(param) => {
let is_closure_param = param
.syntax()
.ancestors()
.nth(2)
.and_then(ast::ClosureExpr::cast)
.is_some();
is_param = Some(if is_closure_param {
ParamKind::Closure
} else {
ParamKind::Function
});
return (PatternRefutability::Irrefutable, param.ty().is_some())
let has_type_ascription = param.ty().is_some();
is_param = (|| {
let fake_param_list = param.syntax().parent().and_then(ast::ParamList::cast)?;
let param_list = find_node_in_file_compensated(original_file, &fake_param_list)?;
let param_list_owner = param_list.syntax().parent()?;
let kind = match_ast! {
match param_list_owner {
ast::ClosureExpr(closure) => ParamKind::Closure(closure),
ast::Fn(fn_) => ParamKind::Function(fn_),
_ => return None,
}
};
Some((param_list, param, kind))
})();
return (PatternRefutability::Irrefutable, has_type_ascription)
},
ast::MatchArm(_) => PatternRefutability::Refutable,
ast::Condition(_) => PatternRefutability::Refutable,
@@ -898,11 +902,29 @@ fn pattern_context_for(pat: ast::Pat) -> PatternContext {
};
(refutability, false)
});
PatternContext { refutability, is_param, has_type_ascription }
PatternContext { refutability, param_ctx: is_param, has_type_ascription }
}
fn find_node_with_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
syntax.covering_element(range).ancestors().find_map(N::cast)
fn find_node_in_file<N: AstNode>(syntax: &SyntaxNode, node: &N) -> Option<N> {
let syntax_range = syntax.text_range();
let range = node.syntax().text_range();
let intersection = range.intersect(syntax_range)?;
syntax.covering_element(intersection).ancestors().find_map(N::cast)
}
/// Compensates for the offset introduced by the fake ident.
/// This is wrong if `node` comes before the insertion point! Use `find_node_in_file` instead.
fn find_node_in_file_compensated<N: AstNode>(syntax: &SyntaxNode, node: &N) -> Option<N> {
let syntax_range = syntax.text_range();
let range = node.syntax().text_range();
let end = range.end().checked_sub(TextSize::try_from(COMPLETION_MARKER.len()).ok()?)?;
if end < range.start() {
return None;
}
let range = TextRange::new(range.start(), end);
// our inserted ident could cause `range` to go outside of the original syntax, so cap it
let intersection = range.intersect(syntax_range)?;
syntax.covering_element(intersection).ancestors().find_map(N::cast)
}
fn path_or_use_tree_qualifier(path: &ast::Path) -> Option<(ast::Path, bool)> {

View File

@@ -87,7 +87,7 @@ fn render_pat(
if matches!(
ctx.completion.pattern_ctx,
Some(PatternContext {
is_param: Some(ParamKind::Function),
param_ctx: Some((.., ParamKind::Function(_))),
has_type_ascription: false,
..
})

View File

@@ -156,3 +156,53 @@ impl A {
"#]],
)
}
// doesn't complete `qux` because there is no expression after it;
// see the comment on source_analyzer::adjust
#[test]
fn local_fn_shows_locals_for_params() {
check(
r#"
fn outer() {
let foo = 3;
{
let bar = 3;
fn inner($0) {}
let baz = 3;
let qux = 3;
}
let fez = 3;
}
"#,
expect![[r#"
bn foo: i32
bn baz: i32
bn bar: i32
kw mut
"#]],
)
}
#[test]
fn closure_shows_locals_for_params() {
check(
r#"
fn outer() {
let foo = 3;
{
let bar = 3;
|$0| {};
let baz = 3;
let qux = 3;
}
let fez = 3;
}
"#,
expect![[r#"
bn baz: i32
bn bar: i32
bn foo: i32
kw mut
"#]],
)
}