From 128f2224afa6bf135a674fc14e5c825f2c76c93f Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Sat, 4 Feb 2023 15:16:59 +0000
Subject: [PATCH 1/2] Remove `OnHit` callback from query caches.

This is not useful now that query results are `Copy`.
---
 compiler/rustc_middle/src/ty/query.rs         | 62 ++++++----------
 .../rustc_query_system/src/query/caches.rs    | 73 +++++--------------
 .../rustc_query_system/src/query/config.rs    |  2 +-
 .../rustc_query_system/src/query/plumbing.rs  | 54 +++++++-------
 4 files changed, 66 insertions(+), 125 deletions(-)

diff --git a/compiler/rustc_middle/src/ty/query.rs b/compiler/rustc_middle/src/ty/query.rs
index 1be819ca610..7151b79c5ab 100644
--- a/compiler/rustc_middle/src/ty/query.rs
+++ b/compiler/rustc_middle/src/ty/query.rs
@@ -106,16 +106,6 @@ impl<'tcx> TyCtxt<'tcx> {
     }
 }
 
-/// Helper for `TyCtxtEnsure` to avoid a closure.
-#[inline(always)]
-fn noop<T>(_: &T) {}
-
-/// Helper to ensure that queries only return `Copy` types.
-#[inline(always)]
-fn copy<T: Copy>(x: &T) -> T {
-    *x
-}
-
 macro_rules! query_helper_param_ty {
     (DefId) => { impl IntoQueryParam<DefId> };
     (LocalDefId) => { impl IntoQueryParam<LocalDefId> };
@@ -225,14 +215,10 @@ macro_rules! define_callbacks {
                 let key = key.into_query_param();
                 opt_remap_env_constness!([$($modifiers)*][key]);
 
-                let cached = try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key, noop);
-
-                match cached {
-                    Ok(()) => return,
-                    Err(()) => (),
-                }
-
-                self.tcx.queries.$name(self.tcx, DUMMY_SP, key, QueryMode::Ensure);
+                match try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key) {
+                    Some(_) => return,
+                    None => self.tcx.queries.$name(self.tcx, DUMMY_SP, key, QueryMode::Ensure),
+                };
             })*
         }
 
@@ -254,14 +240,10 @@ macro_rules! define_callbacks {
                 let key = key.into_query_param();
                 opt_remap_env_constness!([$($modifiers)*][key]);
 
-                let cached = try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key, copy);
-
-                match cached {
-                    Ok(value) => return value,
-                    Err(()) => (),
+                match try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key) {
+                    Some(value) => value,
+                    None => self.tcx.queries.$name(self.tcx, self.span, key, QueryMode::Get).unwrap(),
                 }
-
-                self.tcx.queries.$name(self.tcx, self.span, key, QueryMode::Get).unwrap()
             })*
         }
 
@@ -353,27 +335,25 @@ macro_rules! define_feedable {
                 let tcx = self.tcx;
                 let cache = &tcx.query_caches.$name;
 
-                let cached = try_get_cached(tcx, cache, &key, copy);
-
-                match cached {
-                    Ok(old) => {
+                match try_get_cached(tcx, cache, &key) {
+                    Some(old) => {
                         bug!(
                             "Trying to feed an already recorded value for query {} key={key:?}:\nold value: {old:?}\nnew value: {value:?}",
                             stringify!($name),
-                        );
+                        )
+                    }
+                    None => {
+                        let dep_node = dep_graph::DepNode::construct(tcx, dep_graph::DepKind::$name, &key);
+                        let dep_node_index = tcx.dep_graph.with_feed_task(
+                            dep_node,
+                            tcx,
+                            key,
+                            &value,
+                            hash_result!([$($modifiers)*]),
+                        );
+                        cache.complete(key, value, dep_node_index)
                     }
-                    Err(()) => (),
                 }
-
-                let dep_node = dep_graph::DepNode::construct(tcx, dep_graph::DepKind::$name, &key);
-                let dep_node_index = tcx.dep_graph.with_feed_task(
-                    dep_node,
-                    tcx,
-                    key,
-                    &value,
-                    hash_result!([$($modifiers)*]),
-                );
-                cache.complete(key, value, dep_node_index)
             }
         })*
     }
diff --git a/compiler/rustc_query_system/src/query/caches.rs b/compiler/rustc_query_system/src/query/caches.rs
index 77d0d0314fc..21c89cbc4f1 100644
--- a/compiler/rustc_query_system/src/query/caches.rs
+++ b/compiler/rustc_query_system/src/query/caches.rs
@@ -16,13 +16,13 @@ use std::marker::PhantomData;
 pub trait CacheSelector<'tcx, V> {
     type Cache
     where
-        V: Clone;
+        V: Copy;
     type ArenaCache;
 }
 
 pub trait QueryStorage {
     type Value: Debug;
-    type Stored: Clone;
+    type Stored: Copy;
 
     /// Store a value without putting it in the cache.
     /// This is meant to be used with cycle errors.
@@ -36,14 +36,7 @@ pub trait QueryCache: QueryStorage + Sized {
     /// It returns the shard index and a lock guard to the shard,
     /// which will be used if the query is not in the cache and we need
     /// to compute it.
-    fn lookup<R, OnHit>(
-        &self,
-        key: &Self::Key,
-        // `on_hit` can be called while holding a lock to the query state shard.
-        on_hit: OnHit,
-    ) -> Result<R, ()>
-    where
-        OnHit: FnOnce(&Self::Stored, DepNodeIndex) -> R;
+    fn lookup(&self, key: &Self::Key) -> Option<(Self::Stored, DepNodeIndex)>;
 
     fn complete(&self, key: Self::Key, value: Self::Value, index: DepNodeIndex) -> Self::Stored;
 
@@ -55,7 +48,7 @@ pub struct DefaultCacheSelector<K>(PhantomData<K>);
 impl<'tcx, K: Eq + Hash, V: 'tcx> CacheSelector<'tcx, V> for DefaultCacheSelector<K> {
     type Cache = DefaultCache<K, V>
     where
-        V: Clone;
+        V: Copy;
     type ArenaCache = ArenaCache<'tcx, K, V>;
 }
 
@@ -72,7 +65,7 @@ impl<K, V> Default for DefaultCache<K, V> {
     }
 }
 
-impl<K: Eq + Hash, V: Clone + Debug> QueryStorage for DefaultCache<K, V> {
+impl<K: Eq + Hash, V: Copy + Debug> QueryStorage for DefaultCache<K, V> {
     type Value = V;
     type Stored = V;
 
@@ -86,15 +79,12 @@ impl<K: Eq + Hash, V: Clone + Debug> QueryStorage for DefaultCache<K, V> {
 impl<K, V> QueryCache for DefaultCache<K, V>
 where
     K: Eq + Hash + Clone + Debug,
-    V: Clone + Debug,
+    V: Copy + Debug,
 {
     type Key = K;
 
     #[inline(always)]
-    fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
-    where
-        OnHit: FnOnce(&V, DepNodeIndex) -> R,
-    {
+    fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
         let key_hash = sharded::make_hash(key);
         #[cfg(parallel_compiler)]
         let lock = self.cache.get_shard_by_hash(key_hash).lock();
@@ -102,12 +92,7 @@ where
         let lock = self.cache.lock();
         let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);
 
-        if let Some((_, value)) = result {
-            let hit_result = on_hit(&value.0, value.1);
-            Ok(hit_result)
-        } else {
-            Err(())
-        }
+        if let Some((_, value)) = result { Some(*value) } else { None }
     }
 
     #[inline]
@@ -176,10 +161,7 @@ where
     type Key = K;
 
     #[inline(always)]
-    fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
-    where
-        OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
-    {
+    fn lookup(&self, key: &K) -> Option<(&'tcx V, DepNodeIndex)> {
         let key_hash = sharded::make_hash(key);
         #[cfg(parallel_compiler)]
         let lock = self.cache.get_shard_by_hash(key_hash).lock();
@@ -187,12 +169,7 @@ where
         let lock = self.cache.lock();
         let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);
 
-        if let Some((_, value)) = result {
-            let hit_result = on_hit(&&value.0, value.1);
-            Ok(hit_result)
-        } else {
-            Err(())
-        }
+        if let Some((_, value)) = result { Some((&value.0, value.1)) } else { None }
     }
 
     #[inline]
@@ -234,7 +211,7 @@ pub struct VecCacheSelector<K>(PhantomData<K>);
 impl<'tcx, K: Idx, V: 'tcx> CacheSelector<'tcx, V> for VecCacheSelector<K> {
     type Cache = VecCache<K, V>
     where
-        V: Clone;
+        V: Copy;
     type ArenaCache = VecArenaCache<'tcx, K, V>;
 }
 
@@ -251,7 +228,7 @@ impl<K: Idx, V> Default for VecCache<K, V> {
     }
 }
 
-impl<K: Eq + Idx, V: Clone + Debug> QueryStorage for VecCache<K, V> {
+impl<K: Eq + Idx, V: Copy + Debug> QueryStorage for VecCache<K, V> {
     type Value = V;
     type Stored = V;
 
@@ -265,25 +242,17 @@ impl<K: Eq + Idx, V: Clone + Debug> QueryStorage for VecCache<K, V> {
 impl<K, V> QueryCache for VecCache<K, V>
 where
     K: Eq + Idx + Clone + Debug,
-    V: Clone + Debug,
+    V: Copy + Debug,
 {
     type Key = K;
 
     #[inline(always)]
-    fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
-    where
-        OnHit: FnOnce(&V, DepNodeIndex) -> R,
-    {
+    fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
         #[cfg(parallel_compiler)]
         let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
         #[cfg(not(parallel_compiler))]
         let lock = self.cache.lock();
-        if let Some(Some(value)) = lock.get(*key) {
-            let hit_result = on_hit(&value.0, value.1);
-            Ok(hit_result)
-        } else {
-            Err(())
-        }
+        if let Some(Some(value)) = lock.get(*key) { Some(*value) } else { None }
     }
 
     #[inline]
@@ -357,20 +326,12 @@ where
     type Key = K;
 
     #[inline(always)]
-    fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
-    where
-        OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
-    {
+    fn lookup(&self, key: &K) -> Option<(&'tcx V, DepNodeIndex)> {
         #[cfg(parallel_compiler)]
         let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
         #[cfg(not(parallel_compiler))]
         let lock = self.cache.lock();
-        if let Some(Some(value)) = lock.get(*key) {
-            let hit_result = on_hit(&&value.0, value.1);
-            Ok(hit_result)
-        } else {
-            Err(())
-        }
+        if let Some(Some(value)) = lock.get(*key) { Some((&value.0, value.1)) } else { None }
     }
 
     #[inline]
diff --git a/compiler/rustc_query_system/src/query/config.rs b/compiler/rustc_query_system/src/query/config.rs
index 8c0330e438d..a28e45a5c08 100644
--- a/compiler/rustc_query_system/src/query/config.rs
+++ b/compiler/rustc_query_system/src/query/config.rs
@@ -21,7 +21,7 @@ pub trait QueryConfig<Qcx: QueryContext> {
 
     type Key: DepNodeParams<Qcx::DepContext> + Eq + Hash + Clone + Debug;
     type Value: Debug;
-    type Stored: Debug + Clone + std::borrow::Borrow<Self::Value>;
+    type Stored: Debug + Copy + std::borrow::Borrow<Self::Value>;
 
     type Cache: QueryCache<Key = Self::Key, Stored = Self::Stored, Value = Self::Value>;
 
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index b3b939eae88..bf380f6e2d3 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -130,7 +130,7 @@ fn mk_cycle<Qcx, V, R, D: DepKind>(
 where
     Qcx: QueryContext + crate::query::HasDepContext<DepKind = D>,
     V: std::fmt::Debug + Value<Qcx::DepContext, Qcx::DepKind>,
-    R: Clone,
+    R: Copy,
 {
     let error = report_cycle(qcx.dep_context().sess(), &cycle_error);
     let value = handle_cycle_error(*qcx.dep_context(), &cycle_error, error, handler);
@@ -339,25 +339,21 @@ where
 /// which will be used if the query is not in the cache and we need
 /// to compute it.
 #[inline]
-pub fn try_get_cached<Tcx, C, R, OnHit>(
-    tcx: Tcx,
-    cache: &C,
-    key: &C::Key,
-    // `on_hit` can be called while holding a lock to the query cache
-    on_hit: OnHit,
-) -> Result<R, ()>
+pub fn try_get_cached<Tcx, C>(tcx: Tcx, cache: &C, key: &C::Key) -> Option<C::Stored>
 where
     C: QueryCache,
     Tcx: DepContext,
-    OnHit: FnOnce(&C::Stored) -> R,
 {
-    cache.lookup(&key, |value, index| {
-        if std::intrinsics::unlikely(tcx.profiler().enabled()) {
-            tcx.profiler().query_cache_hit(index.into());
+    match cache.lookup(&key) {
+        Some((value, index)) => {
+            if std::intrinsics::unlikely(tcx.profiler().enabled()) {
+                tcx.profiler().query_cache_hit(index.into());
+            }
+            tcx.dep_graph().read_index(index);
+            Some(value)
         }
-        tcx.dep_graph().read_index(index);
-        on_hit(value)
-    })
+        None => None,
+    }
 }
 
 fn try_execute_query<Q, Qcx>(
@@ -379,17 +375,25 @@ where
             if Q::FEEDABLE {
                 // We may have put a value inside the cache from inside the execution.
                 // Verify that it has the same hash as what we have now, to ensure consistency.
-                let _ = cache.lookup(&key, |cached_result, _| {
+                if let Some((cached_result, _)) = cache.lookup(&key) {
                     let hasher = Q::HASH_RESULT.expect("feedable forbids no_hash");
 
-                    let old_hash = qcx.dep_context().with_stable_hashing_context(|mut hcx| hasher(&mut hcx, cached_result.borrow()));
-                    let new_hash = qcx.dep_context().with_stable_hashing_context(|mut hcx| hasher(&mut hcx, &result));
+                    let old_hash = qcx.dep_context().with_stable_hashing_context(|mut hcx| {
+                        hasher(&mut hcx, cached_result.borrow())
+                    });
+                    let new_hash = qcx
+                        .dep_context()
+                        .with_stable_hashing_context(|mut hcx| hasher(&mut hcx, &result));
                     debug_assert_eq!(
-                        old_hash, new_hash,
+                        old_hash,
+                        new_hash,
                         "Computed query value for {:?}({:?}) is inconsistent with fed value,\ncomputed={:#?}\nfed={:#?}",
-                        Q::DEP_KIND, key, result, cached_result,
+                        Q::DEP_KIND,
+                        key,
+                        result,
+                        cached_result,
                     );
-                });
+                }
             }
             let result = job.complete(cache, result, dep_node_index);
             (result, Some(dep_node_index))
@@ -771,15 +775,11 @@ where
     // We may be concurrently trying both execute and force a query.
     // Ensure that only one of them runs the query.
     let cache = Q::query_cache(qcx);
-    let cached = cache.lookup(&key, |_, index| {
+    if let Some((_, index)) = cache.lookup(&key) {
         if std::intrinsics::unlikely(qcx.dep_context().profiler().enabled()) {
             qcx.dep_context().profiler().query_cache_hit(index.into());
         }
-    });
-
-    match cached {
-        Ok(()) => return,
-        Err(()) => {}
+        return;
     }
 
     let state = Q::query_state(qcx);

From 635ff8e2a8ee1acc3bc5144606d9d12f9ac62d98 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Sat, 4 Feb 2023 15:56:21 +0000
Subject: [PATCH 2/2] Support parallel compiler.

---
 compiler/rustc_query_system/src/query/plumbing.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index bf380f6e2d3..ffc413d15f5 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -404,9 +404,9 @@ where
         }
         #[cfg(parallel_compiler)]
         TryGetJob::JobCompleted(query_blocked_prof_timer) => {
-            let (v, index) = cache
-                .lookup(&key, |value, index| (value.clone(), index))
-                .unwrap_or_else(|_| panic!("value must be in cache after waiting"));
+            let Some((v, index)) = cache.lookup(&key) else {
+                panic!("value must be in cache after waiting")
+            };
 
             if std::intrinsics::unlikely(qcx.dep_context().profiler().enabled()) {
                 qcx.dep_context().profiler().query_cache_hit(index.into());