Skip to content

Commit 970ff76

Browse files
committed
Add an abstraction for custom query caches
1 parent 86874d4 commit 970ff76

File tree

7 files changed

+388
-123
lines changed

7 files changed

+388
-123
lines changed

src/librustc/ty/query/caches.rs

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
use crate::dep_graph::DepNodeIndex;
2+
use crate::ty::query::config::QueryAccessors;
3+
use crate::ty::query::plumbing::{QueryLookup, QueryState, QueryStateShard};
4+
use crate::ty::TyCtxt;
5+
6+
use rustc_data_structures::fx::FxHashMap;
7+
use rustc_data_structures::sharded::Sharded;
8+
use rustc_hir::def_id::{DefId, DefIndex, LOCAL_CRATE};
9+
use rustc_index::vec::IndexVec;
10+
use std::cell::RefCell;
11+
use std::default::Default;
12+
use std::hash::Hash;
13+
14+
pub(crate) trait CacheSelector<K, V> {
15+
type Cache: QueryCache<K, V>;
16+
}
17+
18+
pub(crate) trait QueryCache<K, V>: Default {
19+
type Sharded: Default;
20+
21+
/// Checks if the query is already computed and in the cache.
22+
/// It returns the shard index and a lock guard to the shard,
23+
/// which will be used if the query is not in the cache and we need
24+
/// to compute it.
25+
fn lookup<'tcx, R, GetCache, OnHit, OnMiss, Q>(
26+
&self,
27+
state: &'tcx QueryState<'tcx, Q>,
28+
get_cache: GetCache,
29+
key: K,
30+
// `on_hit` can be called while holding a lock to the query state shard.
31+
on_hit: OnHit,
32+
on_miss: OnMiss,
33+
) -> R
34+
where
35+
Q: QueryAccessors<'tcx>,
36+
GetCache: for<'a> Fn(&'a mut QueryStateShard<'tcx, Q>) -> &'a mut Self::Sharded,
37+
OnHit: FnOnce(&V, DepNodeIndex) -> R,
38+
OnMiss: FnOnce(K, QueryLookup<'tcx, Q>) -> R;
39+
40+
fn complete(
41+
&self,
42+
tcx: TyCtxt<'tcx>,
43+
lock_sharded_storage: &mut Self::Sharded,
44+
key: K,
45+
value: V,
46+
index: DepNodeIndex,
47+
);
48+
49+
fn iter<R, L>(
50+
&self,
51+
shards: &Sharded<L>,
52+
get_shard: impl Fn(&mut L) -> &mut Self::Sharded,
53+
f: impl for<'a> FnOnce(Box<dyn Iterator<Item = (&'a K, &'a V, DepNodeIndex)> + 'a>) -> R,
54+
) -> R;
55+
}
56+
57+
pub struct DefaultCacheSelector;
58+
59+
impl<K: Eq + Hash, V: Clone> CacheSelector<K, V> for DefaultCacheSelector {
60+
type Cache = DefaultCache;
61+
}
62+
63+
#[derive(Default)]
64+
pub struct DefaultCache;
65+
66+
impl<K: Eq + Hash, V: Clone> QueryCache<K, V> for DefaultCache {
67+
type Sharded = FxHashMap<K, (V, DepNodeIndex)>;
68+
69+
#[inline(always)]
70+
fn lookup<'tcx, R, GetCache, OnHit, OnMiss, Q>(
71+
&self,
72+
state: &'tcx QueryState<'tcx, Q>,
73+
get_cache: GetCache,
74+
key: K,
75+
on_hit: OnHit,
76+
on_miss: OnMiss,
77+
) -> R
78+
where
79+
Q: QueryAccessors<'tcx>,
80+
GetCache: for<'a> Fn(&'a mut QueryStateShard<'tcx, Q>) -> &'a mut Self::Sharded,
81+
OnHit: FnOnce(&V, DepNodeIndex) -> R,
82+
OnMiss: FnOnce(K, QueryLookup<'tcx, Q>) -> R,
83+
{
84+
let mut lookup = state.get_lookup(&key);
85+
let lock = &mut *lookup.lock;
86+
87+
let result = get_cache(lock).raw_entry().from_key_hashed_nocheck(lookup.key_hash, &key);
88+
89+
if let Some((_, value)) = result { on_hit(&value.0, value.1) } else { on_miss(key, lookup) }
90+
}
91+
92+
#[inline]
93+
fn complete(
94+
&self,
95+
_: TyCtxt<'tcx>,
96+
lock_sharded_storage: &mut Self::Sharded,
97+
key: K,
98+
value: V,
99+
index: DepNodeIndex,
100+
) {
101+
lock_sharded_storage.insert(key, (value, index));
102+
}
103+
104+
fn iter<R, L>(
105+
&self,
106+
shards: &Sharded<L>,
107+
get_shard: impl Fn(&mut L) -> &mut Self::Sharded,
108+
f: impl for<'a> FnOnce(Box<dyn Iterator<Item = (&'a K, &'a V, DepNodeIndex)> + 'a>) -> R,
109+
) -> R {
110+
let mut shards = shards.lock_shards();
111+
let mut shards: Vec<_> = shards.iter_mut().map(|shard| get_shard(shard)).collect();
112+
let results = shards.iter_mut().flat_map(|shard| shard.iter()).map(|(k, v)| (k, &v.0, v.1));
113+
f(Box::new(results))
114+
}
115+
}
116+
117+
pub struct LocalDenseDefIdCache<V> {
118+
local: RefCell<IndexVec<DefIndex, Option<(V, DepNodeIndex)>>>,
119+
other: DefaultCache,
120+
}
121+
122+
impl<V> Default for LocalDenseDefIdCache<V> {
123+
fn default() -> Self {
124+
LocalDenseDefIdCache { local: RefCell::new(IndexVec::new()), other: Default::default() }
125+
}
126+
}
127+
128+
impl<V: Clone> QueryCache<DefId, V> for LocalDenseDefIdCache<V> {
129+
type Sharded = <DefaultCache as QueryCache<DefId, V>>::Sharded;
130+
131+
#[inline(always)]
132+
fn lookup<'tcx, R, GetCache, OnHit, OnMiss, Q>(
133+
&self,
134+
state: &'tcx QueryState<'tcx, Q>,
135+
get_cache: GetCache,
136+
key: DefId,
137+
on_hit: OnHit,
138+
on_miss: OnMiss,
139+
) -> R
140+
where
141+
Q: QueryAccessors<'tcx>,
142+
GetCache: for<'a> Fn(&'a mut QueryStateShard<'tcx, Q>) -> &'a mut Self::Sharded,
143+
OnHit: FnOnce(&V, DepNodeIndex) -> R,
144+
OnMiss: FnOnce(DefId, QueryLookup<'tcx, Q>) -> R,
145+
{
146+
if key.krate == LOCAL_CRATE {
147+
let local = self.local.borrow();
148+
if let Some(result) = local.get(key.index).and_then(|v| v.as_ref()) {
149+
on_hit(&result.0, result.1)
150+
} else {
151+
drop(local);
152+
let lookup = state.get_lookup(&key);
153+
on_miss(key, lookup)
154+
}
155+
} else {
156+
self.other.lookup(state, get_cache, key, on_hit, on_miss)
157+
}
158+
}
159+
160+
#[inline]
161+
fn complete(
162+
&self,
163+
tcx: TyCtxt<'tcx>,
164+
lock_sharded_storage: &mut Self::Sharded,
165+
key: DefId,
166+
value: V,
167+
index: DepNodeIndex,
168+
) {
169+
if key.krate == LOCAL_CRATE {
170+
let mut local = self.local.borrow_mut();
171+
if local.raw.capacity() == 0 {
172+
*local = IndexVec::from_elem_n(None, tcx.hir().definitions().def_index_count());
173+
}
174+
local[key.index] = Some((value, index));
175+
} else {
176+
self.other.complete(tcx, lock_sharded_storage, key, value, index);
177+
}
178+
}
179+
180+
fn iter<R, L>(
181+
&self,
182+
shards: &Sharded<L>,
183+
get_shard: impl Fn(&mut L) -> &mut Self::Sharded,
184+
f: impl for<'a> FnOnce(Box<dyn Iterator<Item = (&'a DefId, &'a V, DepNodeIndex)> + 'a>) -> R,
185+
) -> R {
186+
let local = self.local.borrow();
187+
let local: Vec<(DefId, &V, DepNodeIndex)> = local
188+
.iter_enumerated()
189+
.filter_map(|(i, e)| e.as_ref().map(|e| (DefId::local(i), &e.0, e.1)))
190+
.collect();
191+
self.other.iter(shards, get_shard, |results| {
192+
f(Box::new(results.chain(local.iter().map(|(id, v, i)| (id, *v, *i)))))
193+
})
194+
}
195+
}

src/librustc/ty/query/config.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use crate::dep_graph::SerializedDepNodeIndex;
22
use crate::dep_graph::{DepKind, DepNode};
3+
use crate::ty::query::caches::QueryCache;
34
use crate::ty::query::plumbing::CycleError;
45
use crate::ty::query::queries;
5-
use crate::ty::query::{Query, QueryCache};
6+
use crate::ty::query::{Query, QueryState};
67
use crate::ty::TyCtxt;
78
use rustc_data_structures::profiling::ProfileCategory;
89
use rustc_hir::def_id::{CrateNum, DefId};
910

1011
use crate::ich::StableHashingContext;
1112
use rustc_data_structures::fingerprint::Fingerprint;
12-
use rustc_data_structures::sharded::Sharded;
1313
use std::borrow::Cow;
1414
use std::fmt::Debug;
1515
use std::hash::Hash;
@@ -30,10 +30,12 @@ pub(crate) trait QueryAccessors<'tcx>: QueryConfig<'tcx> {
3030
const ANON: bool;
3131
const EVAL_ALWAYS: bool;
3232

33+
type Cache: QueryCache<Self::Key, Self::Value>;
34+
3335
fn query(key: Self::Key) -> Query<'tcx>;
3436

3537
// Don't use this method to access query results, instead use the methods on TyCtxt
36-
fn query_cache<'a>(tcx: TyCtxt<'tcx>) -> &'a Sharded<QueryCache<'tcx, Self>>;
38+
fn query_state<'a>(tcx: TyCtxt<'tcx>) -> &'a QueryState<'tcx, Self>;
3739

3840
fn to_dep_node(tcx: TyCtxt<'tcx>, key: &Self::Key) -> DepNode;
3941

@@ -61,7 +63,10 @@ pub(crate) trait QueryDescription<'tcx>: QueryAccessors<'tcx> {
6163
}
6264
}
6365

64-
impl<'tcx, M: QueryAccessors<'tcx, Key = DefId>> QueryDescription<'tcx> for M {
66+
impl<'tcx, M: QueryAccessors<'tcx, Key = DefId>> QueryDescription<'tcx> for M
67+
where
68+
<M as QueryAccessors<'tcx>>::Cache: QueryCache<DefId, <M as QueryConfig<'tcx>>::Value>,
69+
{
6570
default fn describe(tcx: TyCtxt<'_>, def_id: DefId) -> Cow<'static, str> {
6671
if !tcx.sess.verbose() {
6772
format!("processing `{}`", tcx.def_path_str(def_id)).into()

0 commit comments

Comments
 (0)