From 992e3b4f03e0bd348699337317091cc34268684b Mon Sep 17 00:00:00 2001
From: Waffle Lapkin <waffle.lapkin@gmail.com>
Date: Fri, 24 Jan 2025 04:19:10 +0100
Subject: [PATCH] fix tail call checks wrt `#[track_caller]`

only check the caller + disallow caller having the attr.
---
 .../rustc_mir_build/src/check_tail_calls.rs   | 51 ++++++++++---------
 tests/crashes/134336.rs                       | 11 ----
 .../ui/explicit-tail-calls/become-trait-fn.rs | 19 +++++++
 .../callee_is_track_caller.rs                 | 15 ++++++
 .../caller_is_track_caller.rs                 | 16 ++++++
 .../caller_is_track_caller.stderr             | 14 +++++
 6 files changed, 92 insertions(+), 34 deletions(-)
 delete mode 100644 tests/crashes/134336.rs
 create mode 100644 tests/ui/explicit-tail-calls/become-trait-fn.rs
 create mode 100644 tests/ui/explicit-tail-calls/callee_is_track_caller.rs
 create mode 100644 tests/ui/explicit-tail-calls/caller_is_track_caller.rs
 create mode 100644 tests/ui/explicit-tail-calls/caller_is_track_caller.stderr

diff --git a/compiler/rustc_mir_build/src/check_tail_calls.rs b/compiler/rustc_mir_build/src/check_tail_calls.rs
index 0659e3ea314..73f8bc4e901 100644
--- a/compiler/rustc_mir_build/src/check_tail_calls.rs
+++ b/compiler/rustc_mir_build/src/check_tail_calls.rs
@@ -131,11 +131,24 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
         }
 
         {
+            // `#[track_caller]` affects the ABI of a function (by adding a location argument),
+            // so a `track_caller` can only tail call other `track_caller` functions.
+            //
+            // The issue is however that we can't know if a function is `track_caller` or not at
+            // this point (THIR can be polymorphic, we may have an unresolved trait function).
+            // We could only allow functions that we *can* resolve and *are* `track_caller`,
+            // but that would turn changing `track_caller`-ness into a breaking change,
+            // which is probably undesirable.
+            //
+            // Also note that we don't check callee's `track_caller`-ness at all, mostly for the
+            // reasons above, but also because we can always tailcall the shim we'd generate for
+            // coercing the function to an `fn()` pointer. (although in that case the tailcall is
+            // basically useless -- the shim calls the actual function, so tailcalling the shim is
+            // equivalent to calling the function)
             let caller_needs_location = self.needs_location(self.caller_ty);
-            let callee_needs_location = self.needs_location(ty);
 
-            if caller_needs_location != callee_needs_location {
-                self.report_track_caller_mismatch(expr.span, caller_needs_location);
+            if caller_needs_location {
+                self.report_track_caller_caller(expr.span);
             }
         }
 
@@ -149,7 +162,9 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
     }
 
     /// Returns true if function of type `ty` needs location argument
-    /// (i.e. if a function is marked as `#[track_caller]`)
+    /// (i.e. if a function is marked as `#[track_caller]`).
+    ///
+    /// Panics if the function's instance can't be immediately resolved.
     fn needs_location(&self, ty: Ty<'tcx>) -> bool {
         if let &ty::FnDef(did, substs) = ty.kind() {
             let instance =
@@ -292,25 +307,15 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
         self.found_errors = Err(err);
     }
 
-    fn report_track_caller_mismatch(&mut self, sp: Span, caller_needs_location: bool) {
-        let err = match caller_needs_location {
-            true => self
-                .tcx
-                .dcx()
-                .struct_span_err(
-                    sp,
-                    "a function marked with `#[track_caller]` cannot tail-call one that is not",
-                )
-                .emit(),
-            false => self
-                .tcx
-                .dcx()
-                .struct_span_err(
-                    sp,
-                    "a function mot marked with `#[track_caller]` cannot tail-call one that is",
-                )
-                .emit(),
-        };
+    fn report_track_caller_caller(&mut self, sp: Span) {
+        let err = self
+            .tcx
+            .dcx()
+            .struct_span_err(
+                sp,
+                "a function marked with `#[track_caller]` cannot perform a tail-call",
+            )
+            .emit();
 
         self.found_errors = Err(err);
     }
diff --git a/tests/crashes/134336.rs b/tests/crashes/134336.rs
deleted file mode 100644
index 14b88e14f04..00000000000
--- a/tests/crashes/134336.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-//@ known-bug: #134336
-#![expect(incomplete_features)]
-#![feature(explicit_tail_calls)]
-
-trait Tr {
-    fn f();
-}
-
-fn g<T: Tr>() {
-    become T::f();
-}
diff --git a/tests/ui/explicit-tail-calls/become-trait-fn.rs b/tests/ui/explicit-tail-calls/become-trait-fn.rs
new file mode 100644
index 00000000000..03255f15dd0
--- /dev/null
+++ b/tests/ui/explicit-tail-calls/become-trait-fn.rs
@@ -0,0 +1,19 @@
+// regression test for <https://github.com/rust-lang/rust/issues/134336>
+// this previously caused an ICE, because we would compare `#[track_caller]` of
+// the callee and the caller (in tailcalls specifically), leading to a problem
+// since `T::f`'s instance can't be resolved (we do not know if the function is
+// or isn't marked with `#[track_caller]`!)
+//
+//@ check-pass
+#![expect(incomplete_features)]
+#![feature(explicit_tail_calls)]
+
+trait Tr {
+    fn f();
+}
+
+fn g<T: Tr>() {
+    become T::f();
+}
+
+fn main() {}
diff --git a/tests/ui/explicit-tail-calls/callee_is_track_caller.rs b/tests/ui/explicit-tail-calls/callee_is_track_caller.rs
new file mode 100644
index 00000000000..bcb93fda8c8
--- /dev/null
+++ b/tests/ui/explicit-tail-calls/callee_is_track_caller.rs
@@ -0,0 +1,15 @@
+//@ check-pass
+// FIXME(explicit_tail_calls): make this run-pass, once tail calls are properly implemented
+#![expect(incomplete_features)]
+#![feature(explicit_tail_calls)]
+
+fn a(x: u32) -> u32 {
+    become b(x);
+}
+
+#[track_caller]
+fn b(x: u32) -> u32 { x + 42 }
+
+fn main() {
+    assert_eq!(a(12), 54);
+}
diff --git a/tests/ui/explicit-tail-calls/caller_is_track_caller.rs b/tests/ui/explicit-tail-calls/caller_is_track_caller.rs
new file mode 100644
index 00000000000..4e5f3f12f83
--- /dev/null
+++ b/tests/ui/explicit-tail-calls/caller_is_track_caller.rs
@@ -0,0 +1,16 @@
+#![expect(incomplete_features)]
+#![feature(explicit_tail_calls)]
+
+#[track_caller]
+fn a() {
+    become b(); //~ error: a function marked with `#[track_caller]` cannot perform a tail-call
+}
+
+fn b() {}
+
+#[track_caller]
+fn c() {
+    become a(); //~ error: a function marked with `#[track_caller]` cannot perform a tail-call
+}
+
+fn main() {}
diff --git a/tests/ui/explicit-tail-calls/caller_is_track_caller.stderr b/tests/ui/explicit-tail-calls/caller_is_track_caller.stderr
new file mode 100644
index 00000000000..79b9b45986c
--- /dev/null
+++ b/tests/ui/explicit-tail-calls/caller_is_track_caller.stderr
@@ -0,0 +1,14 @@
+error: a function marked with `#[track_caller]` cannot perform a tail-call
+  --> $DIR/caller_is_track_caller.rs:6:5
+   |
+LL |     become b();
+   |     ^^^^^^^^^^
+
+error: a function marked with `#[track_caller]` cannot perform a tail-call
+  --> $DIR/caller_is_track_caller.rs:13:5
+   |
+LL |     become a();
+   |     ^^^^^^^^^^
+
+error: aborting due to 2 previous errors
+