diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index c47988fc4c2..661ef6b7c41 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -1391,8 +1391,13 @@ impl Const { db.const_data(self.id).name.clone() } - pub fn type_ref(self, db: &dyn HirDatabase) -> TypeRef { - db.const_data(self.id).type_ref.as_ref().clone() + pub fn ty(self, db: &dyn HirDatabase) -> Type { + let data = db.const_data(self.id); + let resolver = self.id.resolver(db.upcast()); + let krate = self.id.lookup(db.upcast()).container.krate(db); + let ctx = hir_ty::TyLoweringContext::new(db, &resolver); + let ty = ctx.lower_ty(&data.type_ref); + Type::new_with_resolver_inner(db, krate.id, &resolver, ty) } } @@ -1421,6 +1426,15 @@ impl Static { pub fn is_mut(self, db: &dyn HirDatabase) -> bool { db.static_data(self.id).mutable } + + pub fn ty(self, db: &dyn HirDatabase) -> Type { + let data = db.static_data(self.id); + let resolver = self.id.resolver(db.upcast()); + let krate = self.id.lookup(db.upcast()).container.krate(); + let ctx = hir_ty::TyLoweringContext::new(db, &resolver); + let ty = ctx.lower_ty(&data.type_ref); + Type::new_with_resolver_inner(db, krate, &resolver, ty) + } } impl HasVisibility for Static { diff --git a/crates/ide_assists/src/handlers/extract_function.rs b/crates/ide_assists/src/handlers/extract_function.rs index e6cac754fe0..2ff1511f528 100644 --- a/crates/ide_assists/src/handlers/extract_function.rs +++ b/crates/ide_assists/src/handlers/extract_function.rs @@ -71,7 +71,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option syntax::NodeOrToken::Token(t) => t.parent()?, }; let body = extraction_target(&node, range)?; - let mods = body.analyze_container()?; + let container_info = body.analyze_container(&ctx.sema)?; let (locals_used, self_param) = body.analyze(&ctx.sema); @@ -80,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let module = ctx.sema.scope(&insert_after).module()?; let ret_ty = body.return_ty(ctx)?; - let control_flow = body.external_control_flow(ctx)?; + let control_flow = body.external_control_flow(ctx, &container_info)?; let ret_values = body.ret_values(ctx, node.parent().as_ref().unwrap_or(&node)); let target_range = body.text_range(); @@ -93,7 +93,6 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let outliving_locals: Vec<_> = ret_values.collect(); if stdx::never!(!outliving_locals.is_empty() && !ret_ty.is_unit()) { // We should not have variables that outlive body if we have expression block - stdx::never!(); return; } @@ -107,7 +106,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option ret_ty, body, outliving_locals, - mods, + mods: container_info, }; let new_indent = IndentLevel::from_node(&insert_after); @@ -177,7 +176,7 @@ struct Function { ret_ty: RetType, body: FunctionBody, outliving_locals: Vec, - mods: Modifiers, + mods: ContainerInfo, } #[derive(Debug)] @@ -213,16 +212,22 @@ enum Anchor { Method, } +// FIXME: ControlFlow and ContainerInfo both track some function modifiers, feels like these two should +// probably be merged somehow. #[derive(Debug)] struct ControlFlow { kind: Option, is_async: bool, + is_unsafe: bool, } -#[derive(Copy, Clone, Debug)] -struct Modifiers { +/// The thing whose expression we are extracting from. Can be a function, const, static, const arg, ... +#[derive(Clone, Debug)] +struct ContainerInfo { is_const: bool, is_in_tail: bool, + /// The function's return type, const's type etc. + ret_type: Option, } /// Control flow that is exported from extracted function @@ -244,10 +249,6 @@ enum FlowKind { Try { kind: TryKind, }, - TryReturn { - expr: ast::Expr, - kind: TryKind, - }, /// Break with value (`break $expr;`) Break(Option), /// Continue @@ -295,7 +296,7 @@ struct OutlivedLocal { struct LocalUsages(ide_db::search::UsageSearchResult); impl LocalUsages { - fn find(ctx: &AssistContext, var: Local) -> Self { + fn find_local_usages(ctx: &AssistContext, var: Local) -> Self { Self( Definition::Local(var) .usages(&ctx.sema) @@ -395,7 +396,7 @@ impl FlowKind { match self { FlowKind::Return(_) => make::expr_return(expr), FlowKind::Break(_) => make::expr_break(expr), - FlowKind::Try { .. } | FlowKind::TryReturn { .. } => { + FlowKind::Try { .. } => { stdx::never!("cannot have result handler with try"); expr.unwrap_or_else(|| make::expr_return(None)) } @@ -408,9 +409,7 @@ impl FlowKind { fn expr_ty(&self, ctx: &AssistContext) -> Option { match self { - FlowKind::Return(Some(expr)) - | FlowKind::Break(Some(expr)) - | FlowKind::TryReturn { expr, .. } => { + FlowKind::Return(Some(expr)) | FlowKind::Break(Some(expr)) => { ctx.sema.type_of_expr(expr).map(TypeInfo::adjusted) } FlowKind::Try { .. } => { @@ -620,49 +619,50 @@ impl FunctionBody { (res, self_param) } - fn analyze_container(&self) -> Option { - let mut is_const = false; - let container_expr = self.parent()?.ancestors().find_map(|it| { - // double Option as we want to short circuit - let res = match_ast! { - match it { - ast::ClosureExpr(closure) => closure.body(), + fn analyze_container(&self, sema: &Semantics) -> Option { + let mut ancestors = self.parent()?.ancestors(); + let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted); + let (is_const, expr, ty) = loop { + let anc = ancestors.next()?; + break match_ast! { + match anc { + ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())), ast::EffectExpr(effect) => { - is_const = effect.const_token().is_some(); - effect.block_expr().map(ast::Expr::BlockExpr) + let (constness, block) = match effect.effect() { + ast::Effect::Const(_) => (true, effect.block_expr()), + ast::Effect::Try(_) => (false, effect.block_expr()), + ast::Effect::Label(label) if label.lifetime().is_some() => (false, effect.block_expr()), + _ => continue, + }; + let expr = block.map(ast::Expr::BlockExpr); + (constness, expr.clone(), infer_expr_opt(expr)) }, ast::Fn(fn_) => { - is_const = fn_.const_token().is_some(); - fn_.body().map(ast::Expr::BlockExpr) + (fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(sema.to_def(&fn_)?.ret_type(sema.db))) }, ast::Static(statik) => { - is_const = true; - statik.body() + (true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db))) }, ast::ConstArg(ca) => { - is_const = true; - ca.expr() + (true, ca.expr(), infer_expr_opt(ca.expr())) }, ast::Const(konst) => { - is_const = true; - konst.body() + (true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db))) }, ast::ConstParam(cp) => { - is_const = true; - cp.default_val() + (true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db))) }, ast::ConstBlockPat(cbp) => { - is_const = true; - cbp.block_expr().map(ast::Expr::BlockExpr) + let expr = cbp.block_expr().map(ast::Expr::BlockExpr); + (true, expr.clone(), infer_expr_opt(expr)) }, - ast::Variant(__) => None, - ast::Meta(__) => None, - _ => return None, + ast::Variant(__) => return None, + ast::Meta(__) => return None, + _ => continue, } }; - Some(res) - })??; - let container_tail = match container_expr { + }; + let container_tail = match expr? { ast::Expr::BlockExpr(block) => block.tail_expr(), expr => Some(expr), }; @@ -670,7 +670,7 @@ impl FunctionBody { container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| { container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range()) }); - Some(Modifiers { is_in_tail, is_const }) + Some(ContainerInfo { is_in_tail, is_const, ret_type: ty }) } fn return_ty(&self, ctx: &AssistContext) -> Option { @@ -694,30 +694,44 @@ impl FunctionBody { } /// Analyses the function body for external control flow. - fn external_control_flow(&self, ctx: &AssistContext) -> Option { + fn external_control_flow( + &self, + ctx: &AssistContext, + container_info: &ContainerInfo, + ) -> Option { let mut ret_expr = None; let mut try_expr = None; let mut break_expr = None; let mut continue_expr = None; let mut is_async = false; + let mut _is_unsafe = false; + let mut unsafe_depth = 0; let mut loop_depth = 0; self.preorder_expr(&mut |expr| { let expr = match expr { WalkEvent::Enter(e) => e, - WalkEvent::Leave( - ast::Expr::LoopExpr(_) | ast::Expr::ForExpr(_) | ast::Expr::WhileExpr(_), - ) => { - loop_depth -= 1; + WalkEvent::Leave(expr) => { + match expr { + ast::Expr::LoopExpr(_) + | ast::Expr::ForExpr(_) + | ast::Expr::WhileExpr(_) => loop_depth -= 1, + ast::Expr::EffectExpr(effect) if effect.unsafe_token().is_some() => { + unsafe_depth -= 1 + } + _ => (), + } return false; } - WalkEvent::Leave(_) => return false, }; match expr { ast::Expr::LoopExpr(_) | ast::Expr::ForExpr(_) | ast::Expr::WhileExpr(_) => { loop_depth += 1; } + ast::Expr::EffectExpr(effect) if effect.unsafe_token().is_some() => { + unsafe_depth += 1 + } ast::Expr::ReturnExpr(it) => { ret_expr = Some(it); } @@ -731,31 +745,21 @@ impl FunctionBody { continue_expr = Some(it); } ast::Expr::AwaitExpr(_) => is_async = true, + // FIXME: Do unsafe analysis on expression, sem highlighting knows this so we should be able + // to just lift that out of there + // expr if unsafe_depth ==0 && expr.is_unsafe => is_unsafe = true, _ => {} } false }); let kind = match (try_expr, ret_expr, break_expr, continue_expr) { - (Some(e), None, None, None) => { - let func = e.syntax().ancestors().find_map(ast::Fn::cast)?; - let def = ctx.sema.to_def(&func)?; - let ret_ty = def.ret_type(ctx.db()); + (Some(_), _, None, None) => { + let ret_ty = container_info.ret_type.clone()?; let kind = TryKind::of_ty(ret_ty, ctx)?; Some(FlowKind::Try { kind }) } - (Some(_), Some(r), None, None) => match r.expr() { - Some(expr) => { - if let Some(kind) = expr_err_kind(&expr, ctx) { - Some(FlowKind::TryReturn { expr, kind }) - } else { - cov_mark::hit!(external_control_flow_try_and_return_non_err); - return None; - } - } - None => return None, - }, (Some(_), _, _, _) => { cov_mark::hit!(external_control_flow_try_and_bc); return None; @@ -774,7 +778,7 @@ impl FunctionBody { (None, None, None, None) => None, }; - Some(ControlFlow { kind, is_async }) + Some(ControlFlow { kind, is_async, is_unsafe: _is_unsafe }) } /// find variables that should be extracted as params /// @@ -796,7 +800,7 @@ impl FunctionBody { } }) .map(|var| { - let usages = LocalUsages::find(ctx, var); + let usages = LocalUsages::find_local_usages(ctx, var); let ty = var.ty(ctx.db()); let is_copy = ty.is_copy(ctx.db()); Param { @@ -972,7 +976,7 @@ fn local_outlives_body( local: Local, parent: &SyntaxNode, ) -> Option { - let usages = LocalUsages::find(ctx, local); + let usages = LocalUsages::find_local_usages(ctx, local); let mut has_mut_usages = false; let mut any_outlives = false; for usage in usages.iter() { @@ -1007,24 +1011,6 @@ fn either_syntax(value: &Either) -> &SyntaxNode { } } -/// Checks is expr is `Err(_)` or `None` -fn expr_err_kind(expr: &ast::Expr, ctx: &AssistContext) -> Option { - let func_name = match expr { - ast::Expr::CallExpr(call_expr) => call_expr.expr()?, - ast::Expr::PathExpr(_) => expr.clone(), - _ => return None, - }; - let text = func_name.syntax().text(); - - if text == "Err" { - Some(TryKind::Result { ty: ctx.sema.type_of_expr(expr).map(TypeInfo::original)? }) - } else if text == "None" { - Some(TryKind::Option) - } else { - None - } -} - /// find where to put extracted function definition /// /// Function should be put right after returned node @@ -1133,9 +1119,7 @@ impl FlowHandler { FlowKind::Return(_) | FlowKind::Break(_) => { FlowHandler::IfOption { action } } - FlowKind::Try { kind } | FlowKind::TryReturn { kind, .. } => { - FlowHandler::Try { kind: kind.clone() } - } + FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() }, } } else { match flow_kind { @@ -1145,9 +1129,7 @@ impl FlowHandler { FlowKind::Return(_) | FlowKind::Break(_) => { FlowHandler::MatchResult { err: action } } - FlowKind::Try { kind } | FlowKind::TryReturn { kind, .. } => { - FlowHandler::Try { kind: kind.clone() } - } + FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() }, } } } @@ -1241,22 +1223,25 @@ fn format_function( let body = make_body(ctx, old_indent, new_indent, fun); let const_kw = if fun.mods.is_const { "const " } else { "" }; let async_kw = if fun.control_flow.is_async { "async " } else { "" }; + let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" }; match ctx.config.snippet_cap { Some(_) => format_to!( fn_def, - "\n\n{}{}{}fn $0{}{}", + "\n\n{}{}{}{}fn $0{}{}", new_indent, const_kw, async_kw, + unsafe_kw, fun.name, params ), None => format_to!( fn_def, - "\n\n{}{}{}fn {}{}", + "\n\n{}{}{}{}fn {}{}", new_indent, const_kw, async_kw, + unsafe_kw, fun.name, params ), @@ -1510,7 +1495,7 @@ fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) continue; } - let usages = LocalUsages::find(ctx, param.var); + let usages = LocalUsages::find_local_usages(ctx, param.var); let usages = usages .iter() .filter(|reference| syntax.text_range().contains_range(reference.range)) @@ -3752,8 +3737,7 @@ fn foo() -> Option<()> { #[test] fn try_and_return_ok() { - cov_mark::check!(external_control_flow_try_and_return_non_err); - check_assist_not_applicable( + check_assist( extract_function, r#" //- minicore: result @@ -3767,6 +3751,23 @@ fn foo() -> Result<(), i64> { let h = 1 + m; Ok(()) } +"#, + r#" +fn foo() -> Result<(), i64> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Ok(()) +} + +fn $0fun_name() -> Result { + let k = foo()?; + if k == 42 { + return Ok(1); + } + let m = k + 1; + Ok(m) +} "#, ); }