From 981fc6e174ced2b3448e61b0851ad2db0fd5ddb3 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Fri, 22 Dec 2023 20:52:31 +0000
Subject: [PATCH] Don't pass body into check_closure and child functions

---
 compiler/rustc_hir_typeck/src/closure.rs | 96 ++++++++++++------------
 1 file changed, 46 insertions(+), 50 deletions(-)

diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs
index cd42be28e6f..5d517cd55f3 100644
--- a/compiler/rustc_hir_typeck/src/closure.rs
+++ b/compiler/rustc_hir_typeck/src/closure.rs
@@ -60,25 +60,26 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             }
             None => (None, None),
         };
-        let body = self.tcx.hir().body(closure.body);
-        self.check_closure(closure, expr_span, expected_kind, body, expected_sig)
+
+        self.check_closure(closure, expr_span, expected_kind, expected_sig)
     }
 
-    #[instrument(skip(self, closure, body), level = "debug", ret)]
+    #[instrument(skip(self, closure), level = "debug", ret)]
     fn check_closure(
         &self,
         closure: &hir::Closure<'tcx>,
         expr_span: Span,
         opt_kind: Option<ty::ClosureKind>,
-        body: &'tcx hir::Body<'tcx>,
         expected_sig: Option<ExpectedSig<'tcx>>,
     ) -> Ty<'tcx> {
+        let body = self.tcx.hir().body(closure.body);
+
         trace!("decl = {:#?}", closure.fn_decl);
         let expr_def_id = closure.def_id;
         debug!(?expr_def_id);
 
         let ClosureSignatures { bound_sig, liberated_sig } =
-            self.sig_of_closure(expr_def_id, closure.fn_decl, body, expected_sig);
+            self.sig_of_closure(expr_def_id, closure.fn_decl, body.coroutine_kind, expected_sig);
 
         debug!(?bound_sig, ?liberated_sig);
 
@@ -351,28 +352,28 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         &self,
         expr_def_id: LocalDefId,
         decl: &hir::FnDecl<'_>,
-        body: &hir::Body<'_>,
+        coroutine_kind: Option<hir::CoroutineKind>,
         expected_sig: Option<ExpectedSig<'tcx>>,
     ) -> ClosureSignatures<'tcx> {
         if let Some(e) = expected_sig {
-            self.sig_of_closure_with_expectation(expr_def_id, decl, body, e)
+            self.sig_of_closure_with_expectation(expr_def_id, decl, coroutine_kind, e)
         } else {
-            self.sig_of_closure_no_expectation(expr_def_id, decl, body)
+            self.sig_of_closure_no_expectation(expr_def_id, decl, coroutine_kind)
         }
     }
 
     /// If there is no expected signature, then we will convert the
     /// types that the user gave into a signature.
-    #[instrument(skip(self, expr_def_id, decl, body), level = "debug")]
+    #[instrument(skip(self, expr_def_id, decl), level = "debug")]
     fn sig_of_closure_no_expectation(
         &self,
         expr_def_id: LocalDefId,
         decl: &hir::FnDecl<'_>,
-        body: &hir::Body<'_>,
+        coroutine_kind: Option<hir::CoroutineKind>,
     ) -> ClosureSignatures<'tcx> {
-        let bound_sig = self.supplied_sig_of_closure(expr_def_id, decl, body);
+        let bound_sig = self.supplied_sig_of_closure(expr_def_id, decl, coroutine_kind);
 
-        self.closure_sigs(expr_def_id, body, bound_sig)
+        self.closure_sigs(expr_def_id, bound_sig)
     }
 
     /// Invoked to compute the signature of a closure expression. This
@@ -422,24 +423,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     /// - `expected_sig`: the expected signature (if any). Note that
     ///   this is missing a binder: that is, there may be late-bound
     ///   regions with depth 1, which are bound then by the closure.
-    #[instrument(skip(self, expr_def_id, decl, body), level = "debug")]
+    #[instrument(skip(self, expr_def_id, decl), level = "debug")]
     fn sig_of_closure_with_expectation(
         &self,
         expr_def_id: LocalDefId,
         decl: &hir::FnDecl<'_>,
-        body: &hir::Body<'_>,
+        coroutine_kind: Option<hir::CoroutineKind>,
         expected_sig: ExpectedSig<'tcx>,
     ) -> ClosureSignatures<'tcx> {
         // Watch out for some surprises and just ignore the
         // expectation if things don't see to match up with what we
         // expect.
         if expected_sig.sig.c_variadic() != decl.c_variadic {
-            return self.sig_of_closure_no_expectation(expr_def_id, decl, body);
+            return self.sig_of_closure_no_expectation(expr_def_id, decl, coroutine_kind);
         } else if expected_sig.sig.skip_binder().inputs_and_output.len() != decl.inputs.len() + 1 {
             return self.sig_of_closure_with_mismatched_number_of_arguments(
                 expr_def_id,
                 decl,
-                body,
                 expected_sig,
             );
         }
@@ -463,16 +463,21 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // anonymize away, so as not to confuse the user.
         let bound_sig = self.tcx.anonymize_bound_vars(bound_sig);
 
-        let closure_sigs = self.closure_sigs(expr_def_id, body, bound_sig);
+        let closure_sigs = self.closure_sigs(expr_def_id, bound_sig);
 
         // Up till this point, we have ignored the annotations that the user
         // gave. This function will check that they unify successfully.
         // Along the way, it also writes out entries for types that the user
         // wrote into our typeck results, which are then later used by the privacy
         // check.
-        match self.merge_supplied_sig_with_expectation(expr_def_id, decl, body, closure_sigs) {
+        match self.merge_supplied_sig_with_expectation(
+            expr_def_id,
+            decl,
+            coroutine_kind,
+            closure_sigs,
+        ) {
             Ok(infer_ok) => self.register_infer_ok_obligations(infer_ok),
-            Err(_) => self.sig_of_closure_no_expectation(expr_def_id, decl, body),
+            Err(_) => self.sig_of_closure_no_expectation(expr_def_id, decl, coroutine_kind),
         }
     }
 
@@ -480,7 +485,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         &self,
         expr_def_id: LocalDefId,
         decl: &hir::FnDecl<'_>,
-        body: &hir::Body<'_>,
         expected_sig: ExpectedSig<'tcx>,
     ) -> ClosureSignatures<'tcx> {
         let expr_map_node = self.tcx.hir_node_by_def_id(expr_def_id);
@@ -511,25 +515,25 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         let error_sig = self.error_sig_of_closure(decl, guar);
 
-        self.closure_sigs(expr_def_id, body, error_sig)
+        self.closure_sigs(expr_def_id, error_sig)
     }
 
     /// Enforce the user's types against the expectation. See
     /// `sig_of_closure_with_expectation` for details on the overall
     /// strategy.
-    #[instrument(level = "debug", skip(self, expr_def_id, decl, body, expected_sigs))]
+    #[instrument(level = "debug", skip(self, expr_def_id, decl, expected_sigs))]
     fn merge_supplied_sig_with_expectation(
         &self,
         expr_def_id: LocalDefId,
         decl: &hir::FnDecl<'_>,
-        body: &hir::Body<'_>,
+        coroutine_kind: Option<hir::CoroutineKind>,
         mut expected_sigs: ClosureSignatures<'tcx>,
     ) -> InferResult<'tcx, ClosureSignatures<'tcx>> {
         // Get the signature S that the user gave.
         //
         // (See comment on `sig_of_closure_with_expectation` for the
         // meaning of these letters.)
-        let supplied_sig = self.supplied_sig_of_closure(expr_def_id, decl, body);
+        let supplied_sig = self.supplied_sig_of_closure(expr_def_id, decl, coroutine_kind);
 
         debug!(?supplied_sig);
 
@@ -611,17 +615,17 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     /// types that the user gave into a signature.
     ///
     /// Also, record this closure signature for later.
-    #[instrument(skip(self, decl, body), level = "debug", ret)]
+    #[instrument(skip(self, decl), level = "debug", ret)]
     fn supplied_sig_of_closure(
         &self,
         expr_def_id: LocalDefId,
         decl: &hir::FnDecl<'_>,
-        body: &hir::Body<'_>,
+        coroutine_kind: Option<hir::CoroutineKind>,
     ) -> ty::PolyFnSig<'tcx> {
         let astconv: &dyn AstConv<'_> = self;
 
         trace!("decl = {:#?}", decl);
-        debug!(?body.coroutine_kind);
+        debug!(?coroutine_kind);
 
         let hir_id = self.tcx.local_def_id_to_hir_id(expr_def_id);
         let bound_vars = self.tcx.late_bound_vars(hir_id);
@@ -630,7 +634,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         let supplied_arguments = decl.inputs.iter().map(|a| astconv.ast_ty_to_ty(a));
         let supplied_return = match decl.output {
             hir::FnRetTy::Return(ref output) => astconv.ast_ty_to_ty(output),
-            hir::FnRetTy::DefaultReturn(_) => match body.coroutine_kind {
+            hir::FnRetTy::DefaultReturn(_) => match coroutine_kind {
                 // In the case of the async block that we create for a function body,
                 // we expect the return type of the block to match that of the enclosing
                 // function.
@@ -639,19 +643,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     hir::CoroutineSource::Fn,
                 )) => {
                     debug!("closure is async fn body");
-                    let def_id = self.tcx.hir().body_owner_def_id(body.id());
-                    self.deduce_future_output_from_obligations(expr_def_id, def_id).unwrap_or_else(
-                        || {
-                            // AFAIK, deducing the future output
-                            // always succeeds *except* in error cases
-                            // like #65159. I'd like to return Error
-                            // here, but I can't because I can't
-                            // easily (and locally) prove that we
-                            // *have* reported an
-                            // error. --nikomatsakis
-                            astconv.ty_infer(None, decl.output.span())
-                        },
-                    )
+                    self.deduce_future_output_from_obligations(expr_def_id).unwrap_or_else(|| {
+                        // AFAIK, deducing the future output
+                        // always succeeds *except* in error cases
+                        // like #65159. I'd like to return Error
+                        // here, but I can't because I can't
+                        // easily (and locally) prove that we
+                        // *have* reported an
+                        // error. --nikomatsakis
+                        astconv.ty_infer(None, decl.output.span())
+                    })
                 }
                 // All `gen {}` and `async gen {}` must return unit.
                 Some(
@@ -688,16 +689,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     /// Future<Output = T>`, so we do this by searching through the
     /// obligations to extract the `T`.
     #[instrument(skip(self), level = "debug", ret)]
-    fn deduce_future_output_from_obligations(
-        &self,
-        expr_def_id: LocalDefId,
-        body_def_id: LocalDefId,
-    ) -> Option<Ty<'tcx>> {
+    fn deduce_future_output_from_obligations(&self, body_def_id: LocalDefId) -> Option<Ty<'tcx>> {
         let ret_coercion = self.ret_coercion.as_ref().unwrap_or_else(|| {
-            span_bug!(self.tcx.def_span(expr_def_id), "async fn coroutine outside of a fn")
+            span_bug!(self.tcx.def_span(body_def_id), "async fn coroutine outside of a fn")
         });
 
-        let closure_span = self.tcx.def_span(expr_def_id);
+        let closure_span = self.tcx.def_span(body_def_id);
         let ret_ty = ret_coercion.borrow().expected_ty();
         let ret_ty = self.try_structurally_resolve_type(closure_span, ret_ty);
 
@@ -842,12 +839,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     fn closure_sigs(
         &self,
         expr_def_id: LocalDefId,
-        body: &hir::Body<'_>,
         bound_sig: ty::PolyFnSig<'tcx>,
     ) -> ClosureSignatures<'tcx> {
         let liberated_sig =
             self.tcx().liberate_late_bound_regions(expr_def_id.to_def_id(), bound_sig);
-        let liberated_sig = self.normalize(body.value.span, liberated_sig);
+        let liberated_sig = self.normalize(self.tcx.def_span(expr_def_id), liberated_sig);
         ClosureSignatures { bound_sig, liberated_sig }
     }
 }