From 9ce44be2ab7e9a99eece1c2e254f15ad1c6d73c5 Mon Sep 17 00:00:00 2001
From: Paul Daniel Faria <Nashenas88@users.noreply.github.com>
Date: Wed, 27 May 2020 08:51:08 -0400
Subject: [PATCH] Address review comments, have MissingUnsafe diagnostic point
 to each unsafe use, update tests

---
 crates/ra_hir_ty/src/diagnostics.rs | 15 ++++++---------
 crates/ra_hir_ty/src/expr.rs        | 23 ++++++++++-------------
 crates/ra_hir_ty/src/tests.rs       |  6 +++---
 3 files changed, 19 insertions(+), 25 deletions(-)

diff --git a/crates/ra_hir_ty/src/diagnostics.rs b/crates/ra_hir_ty/src/diagnostics.rs
index 0b7bcdc9def..a59efb34768 100644
--- a/crates/ra_hir_ty/src/diagnostics.rs
+++ b/crates/ra_hir_ty/src/diagnostics.rs
@@ -3,10 +3,7 @@
 use std::any::Any;
 
 use hir_expand::{db::AstDatabase, name::Name, HirFileId, InFile};
-use ra_syntax::{
-    ast::{self, NameOwner},
-    AstNode, AstPtr, SyntaxNodePtr,
-};
+use ra_syntax::{ast, AstNode, AstPtr, SyntaxNodePtr};
 use stdx::format_to;
 
 pub use hir_def::{diagnostics::UnresolvedModule, expr::MatchArm, path::Path};
@@ -176,15 +173,15 @@ impl AstDiagnostic for BreakOutsideOfLoop {
 #[derive(Debug)]
 pub struct MissingUnsafe {
     pub file: HirFileId,
-    pub fn_def: AstPtr<ast::FnDef>,
+    pub expr: AstPtr<ast::Expr>,
 }
 
 impl Diagnostic for MissingUnsafe {
     fn message(&self) -> String {
-        format!("Missing unsafe keyword on fn")
+        format!("This operation is unsafe and requires an unsafe function or block")
     }
     fn source(&self) -> InFile<SyntaxNodePtr> {
-        InFile { file_id: self.file, value: self.fn_def.clone().into() }
+        InFile { file_id: self.file, value: self.expr.clone().into() }
     }
     fn as_any(&self) -> &(dyn Any + Send + 'static) {
         self
@@ -192,11 +189,11 @@ impl Diagnostic for MissingUnsafe {
 }
 
 impl AstDiagnostic for MissingUnsafe {
-    type AST = ast::Name;
+    type AST = ast::Expr;
 
     fn ast(&self, db: &impl AstDatabase) -> Self::AST {
         let root = db.parse_or_expand(self.source().file_id).unwrap();
         let node = self.source().value.to_node(&root);
-        ast::FnDef::cast(node).unwrap().name().unwrap()
+        ast::Expr::cast(node).unwrap()
     }
 }
diff --git a/crates/ra_hir_ty/src/expr.rs b/crates/ra_hir_ty/src/expr.rs
index ce73251b882..5f332aadbba 100644
--- a/crates/ra_hir_ty/src/expr.rs
+++ b/crates/ra_hir_ty/src/expr.rs
@@ -2,9 +2,7 @@
 
 use std::sync::Arc;
 
-use hir_def::{
-    path::path, resolver::HasResolver, src::HasSource, AdtId, DefWithBodyId, FunctionId, Lookup,
-};
+use hir_def::{path::path, resolver::HasResolver, AdtId, DefWithBodyId, FunctionId};
 use hir_expand::diagnostics::DiagnosticSink;
 use ra_syntax::{ast, AstPtr};
 use rustc_hash::FxHashSet;
@@ -346,7 +344,7 @@ pub fn unsafe_expressions(
                 }
             }
             Expr::Call { callee, .. } => {
-                let ty = &infer.type_of_expr[*callee];
+                let ty = &infer[*callee];
                 if let &Ty::Apply(ApplicationTy {
                     ctor: TypeCtor::FnDef(CallableDef::FunctionId(func)),
                     ..
@@ -361,7 +359,7 @@ pub fn unsafe_expressions(
                 if infer
                     .method_resolution(id)
                     .map(|func| db.function_data(func).is_unsafe)
-                    .unwrap_or_else(|| false)
+                    .unwrap_or(false)
                 {
                     unsafe_exprs.push(UnsafeExpr::new(id));
                 }
@@ -409,7 +407,7 @@ impl<'a, 'b> UnsafeValidator<'a, 'b> {
         let func_data = db.function_data(self.func);
         if func_data.is_unsafe
             || unsafe_expressions
-                .into_iter()
+                .iter()
                 .filter(|unsafe_expr| !unsafe_expr.inside_unsafe_block)
                 .count()
                 == 0
@@ -417,12 +415,11 @@ impl<'a, 'b> UnsafeValidator<'a, 'b> {
             return;
         }
 
-        let loc = self.func.lookup(db.upcast());
-        let in_file = loc.source(db.upcast());
-
-        let file = in_file.file_id;
-        let fn_def = AstPtr::new(&in_file.value);
-
-        self.sink.push(MissingUnsafe { file, fn_def })
+        let (_, body_source) = db.body_with_source_map(def);
+        for unsafe_expr in unsafe_expressions {
+            if let Ok(in_file) = body_source.as_ref().expr_syntax(unsafe_expr.expr) {
+                self.sink.push(MissingUnsafe { file: in_file.file_id, expr: in_file.value })
+            }
+        }
     }
 }
diff --git a/crates/ra_hir_ty/src/tests.rs b/crates/ra_hir_ty/src/tests.rs
index 4bc2e8b276f..26b3aeb50e0 100644
--- a/crates/ra_hir_ty/src/tests.rs
+++ b/crates/ra_hir_ty/src/tests.rs
@@ -552,7 +552,7 @@ fn missing_unsafe() {
     .diagnostics()
     .0;
 
-    assert_snapshot!(diagnostics, @r#""fn missing_unsafe() {\n    let x = &5 as *const usize;\n    let y = *x;\n}": Missing unsafe keyword on fn"#);
+    assert_snapshot!(diagnostics, @r#""*x": This operation is unsafe and requires an unsafe function or block"#);
 }
 
 #[test]
@@ -573,7 +573,7 @@ fn missing_unsafe() {
     .diagnostics()
     .0;
 
-    assert_snapshot!(diagnostics, @r#""fn missing_unsafe() {\n    unsafe_fn();\n}": Missing unsafe keyword on fn"#);
+    assert_snapshot!(diagnostics, @r#""unsafe_fn()": This operation is unsafe and requires an unsafe function or block"#);
 }
 
 #[test]
@@ -599,7 +599,7 @@ fn missing_unsafe() {
     .diagnostics()
     .0;
 
-    assert_snapshot!(diagnostics, @r#""fn missing_unsafe() {\n    HasUnsafe.unsafe_fn();\n}": Missing unsafe keyword on fn"#);
+    assert_snapshot!(diagnostics, @r#""HasUnsafe.unsafe_fn()": This operation is unsafe and requires an unsafe function or block"#);
 }
 
 #[test]