Skip to content

Commit ec50f58

Browse files
authored
[MLIR][Support] A cache for cyclical replacers/maps (#98202)
This is a support data structure that acts as a cache for replacer-like functions that map values between two domains. The difference compared to just using a map to cache in-out pairs is that this class is able to handle replacer logic that is self-recursive (and thus may cause infinite recursion in the naive case). This class provides a hook for the user to perform cycle pruning when a cycle is identified, and is able to perform context-sensitive caching so that the replacement result for an input that is part of a pruned cycle can be distinct from the replacement result for the same input when it is not part of a cycle. In addition, this class allows deferring cycle pruning until specific inputs are repeated. This is useful for cases where not all elements in a cycle can perform pruning. The user must still guarantee that at least one element in any given cycle can perform pruning. If that guarantee is violated, an assertion will eventually trip instead of the replacer recursing infinitely (the run-time is linearly bounded by the maximum cycle length of its input).
1 parent d22a419 commit ec50f58

File tree

3 files changed

+759
-0
lines changed

3 files changed

+759
-0
lines changed
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains helper classes for caching replacer-like functions that
10+
// map values between two domains. They are able to handle replacer logic that
11+
// contains self-recursion.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
16+
#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
17+
18+
#include "mlir/IR/Visitors.h"
19+
#include "llvm/ADT/DenseSet.h"
20+
#include "llvm/ADT/MapVector.h"
21+
#include <set>
22+
23+
namespace mlir {
24+
25+
//===----------------------------------------------------------------------===//
26+
// CyclicReplacerCache
27+
//===----------------------------------------------------------------------===//
28+
29+
/// A cache for replacer-like functions that map values between two domains. The
30+
/// difference compared to just using a map to cache in-out pairs is that this
31+
/// class is able to handle replacer logic that is self-recursive (and thus may
32+
/// cause infinite recursion in the naive case).
33+
///
34+
/// This class provides a hook for the user to perform cycle pruning when a
35+
/// cycle is identified, and is able to perform context-sensitive caching so
36+
/// that the replacement result for an input that is part of a pruned cycle can
37+
/// be distinct from the replacement result for the same input when it is not
38+
/// part of a cycle.
39+
///
40+
/// In addition, this class allows deferring cycle pruning until specific inputs
41+
/// are repeated. This is useful for cases where not all elements in a cycle can
42+
/// perform pruning. The user still must guarantee that at least one element in
43+
/// any given cycle can perform pruning. Even if not, an assertion will
44+
/// eventually be tripped instead of infinite recursion (the run-time is
45+
/// linearly bounded by the maximum cycle length of its input).
46+
///
47+
/// WARNING: This class works best with InT & OutT that are trivial scalar
48+
/// types. The input/output elements will be frequently copied and hashed.
49+
template <typename InT, typename OutT>
50+
class CyclicReplacerCache {
51+
public:
52+
/// User-provided replacement function & cycle-breaking functions.
53+
/// The cycle-breaking function must not make any more recursive invocations
54+
/// to this cached replacer.
55+
using CycleBreakerFn = std::function<std::optional<OutT>(InT)>;
56+
57+
CyclicReplacerCache() = delete;
58+
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
59+
: cycleBreaker(std::move(cycleBreaker)) {}
60+
61+
/// A possibly unresolved cache entry.
62+
/// If unresolved, the entry must be resolved before it goes out of scope.
63+
struct CacheEntry {
64+
public:
65+
~CacheEntry() { assert(result && "unresovled cache entry"); }
66+
67+
/// Check whether this node was repeated during recursive replacements.
68+
/// This only makes sense to be called after all recursive replacements are
69+
/// completed and the current element has resurfaced to the top of the
70+
/// replacement stack.
71+
bool wasRepeated() const {
72+
// If the top frame includes itself as a dependency, then it must have
73+
// been repeated.
74+
ReplacementFrame &currFrame = cache.replacementStack.back();
75+
size_t currFrameIndex = cache.replacementStack.size() - 1;
76+
return currFrame.dependentFrames.count(currFrameIndex);
77+
}
78+
79+
/// Resolve an unresolved cache entry by providing the result to be stored
80+
/// in the cache.
81+
void resolve(OutT result) {
82+
assert(!this->result && "cache entry already resolved");
83+
cache.finalizeReplacement(element, result);
84+
this->result = std::move(result);
85+
}
86+
87+
/// Get the resolved result if one exists.
88+
const std::optional<OutT> &get() const { return result; }
89+
90+
private:
91+
friend class CyclicReplacerCache;
92+
CacheEntry() = delete;
93+
CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
94+
std::optional<OutT> result = std::nullopt)
95+
: cache(cache), element(std::move(element)), result(result) {}
96+
97+
CyclicReplacerCache<InT, OutT> &cache;
98+
InT element;
99+
std::optional<OutT> result;
100+
};
101+
102+
/// Lookup the cache for a pre-calculated replacement for `element`.
103+
/// If one exists, a resolved CacheEntry will be returned. Otherwise, an
104+
/// unresolved CacheEntry will be returned, and the caller must resolve it
105+
/// with the calculated replacement so it can be registered in the cache for
106+
/// future use.
107+
/// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
108+
/// CacheEntries that are returned must be resolved in reverse order of
109+
/// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
110+
/// the first retrieved CacheEntry must be resolved last. This should be
111+
/// natural when used as a stack / inside recursion.
112+
CacheEntry lookupOrInit(InT element);
113+
114+
private:
115+
/// Register the replacement in the cache and update the replacementStack.
116+
void finalizeReplacement(InT element, OutT result);
117+
118+
CycleBreakerFn cycleBreaker;
119+
DenseMap<InT, OutT> standaloneCache;
120+
121+
struct DependentReplacement {
122+
OutT replacement;
123+
/// The highest replacement frame index that this cache entry is dependent
124+
/// on.
125+
size_t highestDependentFrame;
126+
};
127+
DenseMap<InT, DependentReplacement> dependentCache;
128+
129+
struct ReplacementFrame {
130+
/// The set of elements that is only legal while under this current frame.
131+
/// They need to be removed from the cache when this frame is popped off the
132+
/// replacement stack.
133+
DenseSet<InT> dependingReplacements;
134+
/// The set of frame indices that this current frame's replacement is
135+
/// dependent on, ordered from highest to lowest.
136+
std::set<size_t, std::greater<size_t>> dependentFrames;
137+
};
138+
/// Every element currently in the progress of being replaced pushes a frame
139+
/// onto this stack.
140+
SmallVector<ReplacementFrame> replacementStack;
141+
/// Maps from each input element to its indices on the replacement stack.
142+
DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
143+
/// If set to true, we are currently asking an element to break a cycle. No
144+
/// more recursive invocations is allowed while this is true (the replacement
145+
/// stack can no longer grow).
146+
bool resolvingCycle = false;
147+
};
148+
149+
template <typename InT, typename OutT>
150+
typename CyclicReplacerCache<InT, OutT>::CacheEntry
151+
CyclicReplacerCache<InT, OutT>::lookupOrInit(InT element) {
152+
assert(!resolvingCycle &&
153+
"illegal recursive invocation while breaking cycle");
154+
155+
if (auto it = standaloneCache.find(element); it != standaloneCache.end())
156+
return CacheEntry(*this, element, it->second);
157+
158+
if (auto it = dependentCache.find(element); it != dependentCache.end()) {
159+
// Update the current top frame (the element that invoked this current
160+
// replacement) to include any dependencies the cache entry had.
161+
ReplacementFrame &currFrame = replacementStack.back();
162+
currFrame.dependentFrames.insert(it->second.highestDependentFrame);
163+
return CacheEntry(*this, element, it->second.replacement);
164+
}
165+
166+
auto [it, inserted] = cyclicElementFrame.try_emplace(element);
167+
if (!inserted) {
168+
// This is a repeat of a known element. Try to break cycle here.
169+
resolvingCycle = true;
170+
std::optional<OutT> result = cycleBreaker(element);
171+
resolvingCycle = false;
172+
if (result) {
173+
// Cycle was broken.
174+
size_t dependentFrame = it->second.back();
175+
dependentCache[element] = {*result, dependentFrame};
176+
ReplacementFrame &currFrame = replacementStack.back();
177+
// If this is a repeat, there is no replacement frame to pop. Mark the top
178+
// frame as being dependent on this element.
179+
currFrame.dependentFrames.insert(dependentFrame);
180+
181+
return CacheEntry(*this, element, *result);
182+
}
183+
184+
// Cycle could not be broken.
185+
// A legal setup must ensure at least one element of each cycle can break
186+
// cycles. Under this setup, each element can be seen at most twice before
187+
// the cycle is broken. If we see an element more than twice, we know this
188+
// is an illegal setup.
189+
assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
190+
}
191+
192+
// Otherwise, either this is the first time we see this element, or this
193+
// element could not break this cycle.
194+
it->second.push_back(replacementStack.size());
195+
replacementStack.emplace_back();
196+
197+
return CacheEntry(*this, element);
198+
}
199+
200+
template <typename InT, typename OutT>
201+
void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
202+
OutT result) {
203+
ReplacementFrame &currFrame = replacementStack.back();
204+
// With the conclusion of this replacement frame, the current element is no
205+
// longer a dependent element.
206+
currFrame.dependentFrames.erase(replacementStack.size() - 1);
207+
208+
auto prevLayerIter = ++replacementStack.rbegin();
209+
if (prevLayerIter == replacementStack.rend()) {
210+
// If this is the last frame, there should be zero dependents.
211+
assert(currFrame.dependentFrames.empty() &&
212+
"internal error: top-level dependent replacement");
213+
// Cache standalone result.
214+
standaloneCache[element] = result;
215+
} else if (currFrame.dependentFrames.empty()) {
216+
// Cache standalone result.
217+
standaloneCache[element] = result;
218+
} else {
219+
// Cache dependent result.
220+
size_t highestDependentFrame = *currFrame.dependentFrames.begin();
221+
dependentCache[element] = {result, highestDependentFrame};
222+
223+
// Otherwise, the previous frame inherits the same dependent frames.
224+
prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
225+
currFrame.dependentFrames.end());
226+
227+
// Mark this current replacement as a depending replacement on the closest
228+
// dependent frame.
229+
replacementStack[highestDependentFrame].dependingReplacements.insert(
230+
element);
231+
}
232+
233+
// All depending replacements in the cache must be purged.
234+
for (InT key : currFrame.dependingReplacements)
235+
dependentCache.erase(key);
236+
237+
replacementStack.pop_back();
238+
auto it = cyclicElementFrame.find(element);
239+
it->second.pop_back();
240+
if (it->second.empty())
241+
cyclicElementFrame.erase(it);
242+
}
243+
244+
//===----------------------------------------------------------------------===//
245+
// CachedCyclicReplacer
246+
//===----------------------------------------------------------------------===//
247+
248+
/// A helper class for cases where the input/output types of the replacer
249+
/// function is identical to the types stored in the cache. This class wraps
250+
/// the user-provided replacer function, and can be used in place of the user
251+
/// function.
252+
template <typename InT, typename OutT>
253+
class CachedCyclicReplacer {
254+
public:
255+
using ReplacerFn = std::function<OutT(InT)>;
256+
using CycleBreakerFn =
257+
typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
258+
259+
CachedCyclicReplacer() = delete;
260+
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
261+
: replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
262+
263+
OutT operator()(InT element) {
264+
auto cacheEntry = cache.lookupOrInit(element);
265+
if (std::optional<OutT> result = cacheEntry.get())
266+
return *result;
267+
268+
OutT result = replacer(element);
269+
cacheEntry.resolve(result);
270+
return result;
271+
}
272+
273+
private:
274+
ReplacerFn replacer;
275+
CyclicReplacerCache<InT, OutT> cache;
276+
};
277+
278+
} // namespace mlir
279+
280+
#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H

mlir/unittests/Support/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_unittest(MLIRSupportTests
2+
CyclicReplacerCacheTest.cpp
23
IndentedOstreamTest.cpp
34
StorageUniquerTest.cpp
45
)

0 commit comments

Comments
 (0)