From c9daec2585c464f2e7d37756c341b217ca795bde Mon Sep 17 00:00:00 2001
From: y21 <30553356+y21@users.noreply.github.com>
Date: Mon, 12 Jun 2023 02:43:23 +0200
Subject: [PATCH] [`unnecessary_fold`]: suggest turbofish if necessary

---
 clippy_lints/src/methods/unnecessary_fold.rs | 124 +++++++++++++++++--
 tests/ui/unnecessary_fold.fixed              |  24 ++++
 tests/ui/unnecessary_fold.rs                 |  24 ++++
 tests/ui/unnecessary_fold.stderr             |  56 ++++++++-
 4 files changed, 217 insertions(+), 11 deletions(-)

diff --git a/clippy_lints/src/methods/unnecessary_fold.rs b/clippy_lints/src/methods/unnecessary_fold.rs
index 5a3d12fd790..de29d99cf8d 100644
--- a/clippy_lints/src/methods/unnecessary_fold.rs
+++ b/clippy_lints/src/methods/unnecessary_fold.rs
@@ -7,10 +7,74 @@ use rustc_errors::Applicability;
 use rustc_hir as hir;
 use rustc_hir::PatKind;
 use rustc_lint::LateContext;
+use rustc_middle::ty;
 use rustc_span::{source_map::Span, sym};
 
 use super::UNNECESSARY_FOLD;
 
+/// No turbofish needed in any case.
+fn no_turbofish(_: &LateContext<'_>, _: &hir::Expr<'_>) -> bool {
+    false
+}
+
+/// Turbofish (`::<T>`) may be needed, but can be omitted if we are certain
+/// that the type can be inferred from usage.
+fn turbofish_if_not_inferred(cx: &LateContext<'_>, expr: &hir::Expr<'_>) -> bool {
+    let parent = cx.tcx.hir().get_parent(expr.hir_id);
+
+    // some common cases where turbofish isn't needed:
+    // - assigned to a local variable with a type annotation
+    if let hir::Node::Local(local) = parent
+        && local.ty.is_some()
+    {
+        return false;
+    }
+
+    // - part of a function call argument, can be inferred from the function signature (provided that
+    //   the parameter is not a generic type parameter)
+    if let hir::Node::Expr(parent_expr) = parent
+        && let hir::ExprKind::Call(recv, args) = parent_expr.kind
+        && let hir::ExprKind::Path(ref qpath) = recv.kind
+        && let Some(fn_def_id) = cx.qpath_res(qpath, recv.hir_id).opt_def_id()
+        && let fn_sig = cx.tcx.fn_sig(fn_def_id).skip_binder().skip_binder()
+        && let Some(arg_pos) = args.iter().position(|arg| arg.hir_id == expr.hir_id)
+        && let Some(ty) = fn_sig.inputs().get(arg_pos)
+        && !matches!(ty.kind(), ty::Param(_))
+    {
+        return false;
+    }
+
+    // if it's neither of those, stay on the safe side and suggest turbofish,
+    // even if it could work!
+    true
+}
+
+#[derive(Copy, Clone)]
+struct Replacement {
+    method_name: &'static str,
+    has_args: bool,
+    requires_turbofish: fn(&LateContext<'_>, &hir::Expr<'_>) -> bool,
+}
+impl Replacement {
+    /// `any(f)`, `all(f)`
+    pub fn non_generic(method_name: &'static str) -> Self {
+        Self {
+            method_name,
+            has_args: true,
+            requires_turbofish: no_turbofish,
+        }
+    }
+
+    /// `sum::<T>()`, `product::<T>()`
+    pub fn generic(method_name: &'static str) -> Self {
+        Self {
+            method_name,
+            has_args: false,
+            requires_turbofish: turbofish_if_not_inferred,
+        }
+    }
+}
+
 pub(super) fn check(
     cx: &LateContext<'_>,
     expr: &hir::Expr<'_>,
@@ -24,8 +88,7 @@ pub(super) fn check(
         acc: &hir::Expr<'_>,
         fold_span: Span,
         op: hir::BinOpKind,
-        replacement_method_name: &str,
-        replacement_has_args: bool,
+        replacement: Replacement,
     ) {
         if_chain! {
             // Extract the body of the closure passed to fold
@@ -43,18 +106,27 @@ pub(super) fn check(
             if let PatKind::Binding(_, second_arg_id, second_arg_ident, _) = strip_pat_refs(param_b.pat).kind;
 
             if path_to_local_id(left_expr, first_arg_id);
-            if replacement_has_args || path_to_local_id(right_expr, second_arg_id);
+            if replacement.has_args || path_to_local_id(right_expr, second_arg_id);
 
             then {
                 let mut applicability = Applicability::MachineApplicable;
-                let sugg = if replacement_has_args {
+
+                let turbofish = if (replacement.requires_turbofish)(cx, expr) {
+                    format!("::<{}>", cx.typeck_results().expr_ty_adjusted(right_expr).peel_refs())
+                } else {
+                    String::new()
+                };
+
+                let sugg = if replacement.has_args {
                     format!(
-                        "{replacement_method_name}(|{second_arg_ident}| {r})",
+                        "{method}{turbofish}(|{second_arg_ident}| {r})",
+                        method = replacement.method_name,
                         r = snippet_with_applicability(cx, right_expr.span, "EXPR", &mut applicability),
                     )
                 } else {
                     format!(
-                        "{replacement_method_name}()",
+                        "{method}{turbofish}()",
+                        method = replacement.method_name,
                     )
                 };
 
@@ -80,11 +152,43 @@ pub(super) fn check(
     // Check if the first argument to .fold is a suitable literal
     if let hir::ExprKind::Lit(lit) = init.kind {
         match lit.node {
-            ast::LitKind::Bool(false) => check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Or, "any", true),
-            ast::LitKind::Bool(true) => check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::And, "all", true),
-            ast::LitKind::Int(0, _) => check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Add, "sum", false),
+            ast::LitKind::Bool(false) => {
+                check_fold_with_op(
+                    cx,
+                    expr,
+                    acc,
+                    fold_span,
+                    hir::BinOpKind::Or,
+                    Replacement::non_generic("any"),
+                );
+            },
+            ast::LitKind::Bool(true) => {
+                check_fold_with_op(
+                    cx,
+                    expr,
+                    acc,
+                    fold_span,
+                    hir::BinOpKind::And,
+                    Replacement::non_generic("all"),
+                );
+            },
+            ast::LitKind::Int(0, _) => check_fold_with_op(
+                cx,
+                expr,
+                acc,
+                fold_span,
+                hir::BinOpKind::Add,
+                Replacement::generic("sum"),
+            ),
             ast::LitKind::Int(1, _) => {
-                check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Mul, "product", false);
+                check_fold_with_op(
+                    cx,
+                    expr,
+                    acc,
+                    fold_span,
+                    hir::BinOpKind::Mul,
+                    Replacement::generic("product"),
+                );
             },
             _ => (),
         }
diff --git a/tests/ui/unnecessary_fold.fixed b/tests/ui/unnecessary_fold.fixed
index 2bed14973ca..bd1d4a152ae 100644
--- a/tests/ui/unnecessary_fold.fixed
+++ b/tests/ui/unnecessary_fold.fixed
@@ -49,4 +49,28 @@ fn unnecessary_fold_over_multiple_lines() {
         .any(|x| x > 2);
 }
 
+fn issue10000() {
+    use std::collections::HashMap;
+    use std::hash::BuildHasher;
+
+    fn anything<T>(_: T) {}
+    fn num(_: i32) {}
+    fn smoketest_map<S: BuildHasher>(mut map: HashMap<i32, i32, S>) {
+        map.insert(0, 0);
+        assert_eq!(map.values().sum::<i32>(), 0);
+
+        // more cases:
+        let _ = map.values().sum::<i32>();
+        let _ = map.values().product::<i32>();
+        let _: i32 = map.values().sum();
+        let _: i32 = map.values().product();
+        anything(map.values().sum::<i32>());
+        anything(map.values().product::<i32>());
+        num(map.values().sum());
+        num(map.values().product());
+    }
+
+    smoketest_map(HashMap::new());
+}
+
 fn main() {}
diff --git a/tests/ui/unnecessary_fold.rs b/tests/ui/unnecessary_fold.rs
index a3cec8ea3d5..d27cc460c44 100644
--- a/tests/ui/unnecessary_fold.rs
+++ b/tests/ui/unnecessary_fold.rs
@@ -49,4 +49,28 @@ fn unnecessary_fold_over_multiple_lines() {
         .fold(false, |acc, x| acc || x > 2);
 }
 
+fn issue10000() {
+    use std::collections::HashMap;
+    use std::hash::BuildHasher;
+
+    fn anything<T>(_: T) {}
+    fn num(_: i32) {}
+    fn smoketest_map<S: BuildHasher>(mut map: HashMap<i32, i32, S>) {
+        map.insert(0, 0);
+        assert_eq!(map.values().fold(0, |x, y| x + y), 0);
+
+        // more cases:
+        let _ = map.values().fold(0, |x, y| x + y);
+        let _ = map.values().fold(1, |x, y| x * y);
+        let _: i32 = map.values().fold(0, |x, y| x + y);
+        let _: i32 = map.values().fold(1, |x, y| x * y);
+        anything(map.values().fold(0, |x, y| x + y));
+        anything(map.values().fold(1, |x, y| x * y));
+        num(map.values().fold(0, |x, y| x + y));
+        num(map.values().fold(1, |x, y| x * y));
+    }
+
+    smoketest_map(HashMap::new());
+}
+
 fn main() {}
diff --git a/tests/ui/unnecessary_fold.stderr b/tests/ui/unnecessary_fold.stderr
index 22c44588ab7..98979f7477f 100644
--- a/tests/ui/unnecessary_fold.stderr
+++ b/tests/ui/unnecessary_fold.stderr
@@ -36,5 +36,59 @@ error: this `.fold` can be written more succinctly using another method
 LL |         .fold(false, |acc, x| acc || x > 2);
    |          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: try: `any(|x| x > 2)`
 
-error: aborting due to 6 previous errors
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:60:33
+   |
+LL |         assert_eq!(map.values().fold(0, |x, y| x + y), 0);
+   |                                 ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum::<i32>()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:63:30
+   |
+LL |         let _ = map.values().fold(0, |x, y| x + y);
+   |                              ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum::<i32>()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:64:30
+   |
+LL |         let _ = map.values().fold(1, |x, y| x * y);
+   |                              ^^^^^^^^^^^^^^^^^^^^^ help: try: `product::<i32>()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:65:35
+   |
+LL |         let _: i32 = map.values().fold(0, |x, y| x + y);
+   |                                   ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:66:35
+   |
+LL |         let _: i32 = map.values().fold(1, |x, y| x * y);
+   |                                   ^^^^^^^^^^^^^^^^^^^^^ help: try: `product()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:67:31
+   |
+LL |         anything(map.values().fold(0, |x, y| x + y));
+   |                               ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum::<i32>()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:68:31
+   |
+LL |         anything(map.values().fold(1, |x, y| x * y));
+   |                               ^^^^^^^^^^^^^^^^^^^^^ help: try: `product::<i32>()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:69:26
+   |
+LL |         num(map.values().fold(0, |x, y| x + y));
+   |                          ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum()`
+
+error: this `.fold` can be written more succinctly using another method
+  --> $DIR/unnecessary_fold.rs:70:26
+   |
+LL |         num(map.values().fold(1, |x, y| x * y));
+   |                          ^^^^^^^^^^^^^^^^^^^^^ help: try: `product()`
+
+error: aborting due to 15 previous errors