From b6d574642d03676cdaae5060be21b11640200086 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 5 Aug 2021 02:54:06 +0200 Subject: [PATCH] extract_type_alias extracts generics correctly --- .../src/handlers/extract_type_alias.rs | 172 ++++++++++++++++-- crates/syntax/src/ast/node_ext.rs | 20 +- 2 files changed, 175 insertions(+), 17 deletions(-) diff --git a/crates/ide_assists/src/handlers/extract_type_alias.rs b/crates/ide_assists/src/handlers/extract_type_alias.rs index eac8857c678..4913ac1e08e 100644 --- a/crates/ide_assists/src/handlers/extract_type_alias.rs +++ b/crates/ide_assists/src/handlers/extract_type_alias.rs @@ -1,5 +1,7 @@ +use either::Either; +use itertools::Itertools; use syntax::{ - ast::{self, edit::IndentLevel, AstNode}, + ast::{self, edit::IndentLevel, AstNode, GenericParamsOwner, NameOwner}, match_ast, }; @@ -27,41 +29,158 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext) -> Opti return None; } - let node = ctx.find_node_at_range::()?; - let item = ctx.find_node_at_offset::()?; - let insert = match_ast! { - match (item.syntax().parent()?) { - ast::AssocItemList(it) => it.syntax().parent()?, - _ => item.syntax().clone(), + let ty = ctx.find_node_at_range::()?; + let item = ty.syntax().ancestors().find_map(ast::Item::cast)?; + let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| { + match_ast! { + match it { + ast::Trait(tr) => Some(Either::Left(tr)), + ast::Impl(impl_) => Some(Either::Right(impl_)), + _ => None, + } } - }; - let indent = IndentLevel::from_node(&insert); - let insert = insert.text_range().start(); - let target = node.syntax().text_range(); + }); + let node = assoc_owner.as_ref().map_or_else( + || item.syntax(), + |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax), + ); + let insert_pos = node.text_range().start(); + let target = ty.syntax().text_range(); acc.add( AssistId("extract_type_alias", AssistKind::RefactorExtract), "Extract type as type alias", target, |builder| { - builder.edit_file(ctx.frange.file_id); - builder.replace(target, "Type"); + let mut known_generics = match item.generic_param_list() { + Some(it) => it.generic_params().collect(), + None => Vec::new(), + }; + if let Some(it) = assoc_owner.as_ref().and_then(|it| match it { + Either::Left(it) => it.generic_param_list(), + Either::Right(it) => it.generic_param_list(), + }) { + known_generics.extend(it.generic_params()); + } + let generics = collect_used_generics(&ty, &known_generics); + + let replacement = if !generics.is_empty() { + format!( + "Type<{}>", + generics.iter().format_with(", ", |generic, f| { + match generic { + ast::GenericParam::ConstParam(cp) => f(&cp.name().unwrap()), + ast::GenericParam::LifetimeParam(lp) => f(&lp.lifetime().unwrap()), + ast::GenericParam::TypeParam(tp) => f(&tp.name().unwrap()), + } + }) + ) + } else { + String::from("Type") + }; + builder.replace(target, replacement); + + let indent = IndentLevel::from_node(node); + let generics = if !generics.is_empty() { + format!("<{}>", generics.iter().format(", ")) + } else { + String::new() + }; match ctx.config.snippet_cap { Some(cap) => { builder.insert_snippet( cap, - insert, - format!("type $0Type = {};\n\n{}", node, indent), + insert_pos, + format!("type $0Type{} = {};\n\n{}", generics, ty, indent), ); } None => { - builder.insert(insert, format!("type Type = {};\n\n{}", node, indent)); + builder.insert( + insert_pos, + format!("type Type{} = {};\n\n{}", generics, ty, indent), + ); } } }, ) } +fn collect_used_generics<'gp>( + ty: &ast::Type, + known_generics: &'gp [ast::GenericParam], +) -> Vec<&'gp ast::GenericParam> { + // can't use a closure -> closure here cause lifetime inference fails for that + fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ { + move |gp: &&ast::GenericParam| match gp { + ast::GenericParam::LifetimeParam(lp) => { + lp.lifetime().map_or(false, |lt| lt.text() == text) + } + _ => false, + } + } + + let mut generics = Vec::new(); + ty.walk(&mut |ty| match ty { + ast::Type::PathType(ty) => { + if let Some(path) = ty.path() { + if let Some(name_ref) = path.as_single_name_ref() { + if let Some(param) = known_generics.iter().find(|gp| { + match gp { + ast::GenericParam::ConstParam(cp) => cp.name(), + ast::GenericParam::TypeParam(tp) => tp.name(), + _ => None, + } + .map_or(false, |n| n.text() == name_ref.text()) + }) { + generics.push(param); + } + } + generics.extend( + path.segments() + .filter_map(|seg| seg.generic_arg_list()) + .flat_map(|it| it.generic_args()) + .filter_map(|it| match it { + ast::GenericArg::LifetimeArg(lt) => { + let lt = lt.lifetime()?; + known_generics.iter().find(find_lifetime(<.text())) + } + _ => None, + }), + ); + } + } + ast::Type::ImplTraitType(impl_ty) => { + if let Some(it) = impl_ty.type_bound_list() { + generics.extend( + it.bounds() + .filter_map(|it| it.lifetime()) + .filter_map(|lt| known_generics.iter().find(find_lifetime(<.text()))), + ); + } + } + ast::Type::DynTraitType(dyn_ty) => { + if let Some(it) = dyn_ty.type_bound_list() { + generics.extend( + it.bounds() + .filter_map(|it| it.lifetime()) + .filter_map(|lt| known_generics.iter().find(find_lifetime(<.text()))), + ); + } + } + ast::Type::RefType(ref_) => generics.extend( + ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(<.text()))), + ), + _ => (), + }); + // stable resort to lifetime, type, const + generics.sort_by_key(|gp| match gp { + ast::GenericParam::ConstParam(_) => 2, + ast::GenericParam::LifetimeParam(_) => 0, + ast::GenericParam::TypeParam(_) => 1, + }); + generics +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -216,4 +335,25 @@ mod m { "#, ); } + + #[test] + fn generics() { + check_assist( + extract_type_alias, + r#" +struct Struct; +impl<'outer, Outer, const OUTER: usize> () { + fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct, Struct, Outer, &'inner (), Inner, &'outer ())$0) {} +} +"#, + r#" +struct Struct; +type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct, Struct, Outer, &'inner (), Inner, &'outer ()); + +impl<'outer, Outer, const OUTER: usize> () { + fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {} +} +"#, + ); + } } diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index 0a540d9cfb3..3030c881209 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -8,7 +8,10 @@ use parser::SyntaxKind; use rowan::{GreenNodeData, GreenTokenData, WalkEvent}; use crate::{ - ast::{self, support, AstChildren, AstNode, AstToken, AttrsOwner, NameOwner, SyntaxNode}, + ast::{ + self, support, AstChildren, AstNode, AstToken, AttrsOwner, GenericParamsOwner, NameOwner, + SyntaxNode, + }, NodeOrToken, SmolStr, SyntaxElement, SyntaxToken, TokenText, T, }; @@ -593,6 +596,21 @@ impl ast::Variant { } } +impl ast::Item { + pub fn generic_param_list(&self) -> Option { + match self { + ast::Item::Enum(it) => it.generic_param_list(), + ast::Item::Fn(it) => it.generic_param_list(), + ast::Item::Impl(it) => it.generic_param_list(), + ast::Item::Struct(it) => it.generic_param_list(), + ast::Item::Trait(it) => it.generic_param_list(), + ast::Item::TypeAlias(it) => it.generic_param_list(), + ast::Item::Union(it) => it.generic_param_list(), + _ => None, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum FieldKind { Name(ast::NameRef),