Skip to content

Commit bcc3eda

Browse files
committed
add replacer cache and gtests
1 parent 0162386 commit bcc3eda

File tree

3 files changed

+750
-0
lines changed

3 files changed

+750
-0
lines changed
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains helper classes for caching replacer-like functions that
10+
// map values between two domains. They are able to handle replacer logic that
11+
// contains self-recursion.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
16+
#define MLIR_SUPPORT_CACHINGREPLACER_H
17+
18+
#include "mlir/IR/Visitors.h"
19+
#include "llvm/ADT/DenseSet.h"
20+
#include "llvm/ADT/MapVector.h"
21+
#include <set>
22+
23+
namespace mlir {
24+
25+
//===----------------------------------------------------------------------===//
26+
// CyclicReplacerCache
27+
//===----------------------------------------------------------------------===//
28+
29+
/// A cache for replacer-like functions that map values between two domains. The
30+
/// difference compared to just using a map to cache in-out pairs is that this
31+
/// class is able to handle replacer logic that is self-recursive (and thus may
32+
/// cause infinite recursion in the naive case).
33+
///
34+
/// This class provides a hook for the user to perform cycle pruning when a
35+
/// cycle is identified, and is able to perform context-sensitive caching so
36+
/// that the replacement result for an input that is part of a pruned cycle can
37+
/// be distinct from the replacement result for the same input when it is not
38+
/// part of a cycle.
39+
///
40+
/// In addition, this class allows deferring cycle pruning until specific inputs
41+
/// are repeated. This is useful for cases where not all elements in a cycle can
42+
/// perform pruning. The user still must guarantee that at least one element in
43+
/// any given cycle can perform pruning. Even if not, an assertion will
44+
/// eventually be tripped instead of infinite recursion (the run-time is
45+
/// linearly bounded by the maximum cycle length of its input).
46+
template <typename InT, typename OutT>
47+
class CyclicReplacerCache {
48+
public:
49+
/// User-provided replacement function & cycle-breaking functions.
50+
/// The cycle-breaking function must not make any more recursive invocations
51+
/// to this cached replacer.
52+
using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
53+
54+
CyclicReplacerCache() = delete;
55+
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
56+
: cycleBreaker(std::move(cycleBreaker)) {}
57+
58+
/// A possibly unresolved cache entry.
59+
/// If unresolved, the entry must be resolved before it goes out of scope.
60+
struct CacheEntry {
61+
public:
62+
~CacheEntry() { assert(result && "unresovled cache entry"); }
63+
64+
/// Check whether this node was repeated during recursive replacements.
65+
/// This only makes sense to be called after all recursive replacements are
66+
/// completed and the current element has resurfaced to the top of the
67+
/// replacement stack.
68+
bool wasRepeated() const {
69+
// If the top frame includes itself as a dependency, then it must have
70+
// been repeated.
71+
ReplacementFrame &currFrame = cache.replacementStack.back();
72+
size_t currFrameIndex = cache.replacementStack.size() - 1;
73+
return currFrame.dependentFrames.count(currFrameIndex);
74+
}
75+
76+
/// Resolve an unresolved cache entry by providing the result to be stored
77+
/// in the cache.
78+
void resolve(OutT result) {
79+
assert(!this->result && "cache entry already resolved");
80+
this->result = result;
81+
cache.finalizeReplacement(element, result);
82+
}
83+
84+
/// Get the resolved result if one exists.
85+
std::optional<OutT> get() { return result; }
86+
87+
private:
88+
friend class CyclicReplacerCache;
89+
CacheEntry() = delete;
90+
CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
91+
std::optional<OutT> result = std::nullopt)
92+
: cache(cache), element(element), result(result) {}
93+
94+
CyclicReplacerCache<InT, OutT> &cache;
95+
InT element;
96+
std::optional<OutT> result;
97+
};
98+
99+
/// Lookup the cache for a pre-calculated replacement for `element`.
100+
/// If one exists, a resolved CacheEntry will be returned. Otherwise, an
101+
/// unresolved CacheEntry will be returned, and the caller must resolve it
102+
/// with the calculated replacement so it can be registered in the cache for
103+
/// future use.
104+
/// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
105+
/// CacheEntries that are returned must be resolved in reverse order of
106+
/// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
107+
/// the first retrieved CacheEntry must be resolved last. This should be
108+
/// natural when used as a stack / inside recursion.
109+
CacheEntry lookupOrInit(const InT &element);
110+
111+
private:
112+
/// Register the replacement in the cache and update the replacementStack.
113+
void finalizeReplacement(const InT &element, const OutT &result);
114+
115+
CycleBreakerFn cycleBreaker;
116+
DenseMap<InT, OutT> standaloneCache;
117+
118+
struct DependentReplacement {
119+
OutT replacement;
120+
/// The highest replacement frame index that this cache entry is dependent
121+
/// on.
122+
size_t highestDependentFrame;
123+
};
124+
DenseMap<InT, DependentReplacement> dependentCache;
125+
126+
struct ReplacementFrame {
127+
/// The set of elements that is only legal while under this current frame.
128+
/// They need to be removed from the cache when this frame is popped off the
129+
/// replacement stack.
130+
DenseSet<InT> dependingReplacements;
131+
/// The set of frame indices that this current frame's replacement is
132+
/// dependent on, ordered from highest to lowest.
133+
std::set<size_t, std::greater<size_t>> dependentFrames;
134+
};
135+
/// Every element currently in the progress of being replaced pushes a frame
136+
/// onto this stack.
137+
SmallVector<ReplacementFrame> replacementStack;
138+
/// Maps from each input element to its indices on the replacement stack.
139+
DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
140+
/// If set to true, we are currently asking an element to break a cycle. No
141+
/// more recursive invocations is allowed while this is true (the replacement
142+
/// stack can no longer grow).
143+
bool resolvingCycle = false;
144+
};
145+
146+
template <typename InT, typename OutT>
147+
typename CyclicReplacerCache<InT, OutT>::CacheEntry
148+
CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
149+
assert(!resolvingCycle &&
150+
"illegal recursive invocation while breaking cycle");
151+
152+
if (auto it = standaloneCache.find(element); it != standaloneCache.end())
153+
return CacheEntry(*this, element, it->second);
154+
155+
if (auto it = dependentCache.find(element); it != dependentCache.end()) {
156+
// pdate the current top frame (the element that invoked this current
157+
// replacement) to include any dependencies the cache entry had.
158+
ReplacementFrame &currFrame = replacementStack.back();
159+
currFrame.dependentFrames.insert(it->second.highestDependentFrame);
160+
return CacheEntry(*this, element, it->second.replacement);
161+
}
162+
163+
auto [it, inserted] = cyclicElementFrame.try_emplace(element);
164+
if (!inserted) {
165+
// This is a repeat of a known element. Try to break cycle here.
166+
resolvingCycle = true;
167+
std::optional<OutT> result = cycleBreaker(element);
168+
resolvingCycle = false;
169+
if (result) {
170+
// Cycle was broken.
171+
size_t dependentFrame = it->second.back();
172+
dependentCache[element] = {*result, dependentFrame};
173+
ReplacementFrame &currFrame = replacementStack.back();
174+
// If this is a repeat, there is no replacement frame to pop. Mark the top
175+
// frame as being dependent on this element.
176+
currFrame.dependentFrames.insert(dependentFrame);
177+
178+
return CacheEntry(*this, element, *result);
179+
}
180+
181+
// Cycle could not be broken.
182+
// A legal setup must ensure at least one element of each cycle can break
183+
// cycles. Under this setup, each element can be seen at most twice before
184+
// the cycle is broken. If we see an element more than twice, we know this
185+
// is an illegal setup.
186+
assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
187+
}
188+
189+
// Otherwise, either this is the first time we see this element, or this
190+
// element could not break this cycle.
191+
it->second.push_back(replacementStack.size());
192+
replacementStack.emplace_back();
193+
194+
return CacheEntry(*this, element);
195+
}
196+
197+
template <typename InT, typename OutT>
198+
void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
199+
const OutT &result) {
200+
ReplacementFrame &currFrame = replacementStack.back();
201+
// With the conclusion of this replacement frame, the current element is no
202+
// longer a dependent element.
203+
currFrame.dependentFrames.erase(replacementStack.size() - 1);
204+
205+
auto prevLayerIter = ++replacementStack.rbegin();
206+
if (prevLayerIter == replacementStack.rend()) {
207+
// If this is the last frame, there should be zero dependents.
208+
assert(currFrame.dependentFrames.empty() &&
209+
"internal error: top-level dependent replacement");
210+
// Cache standalone result.
211+
standaloneCache[element] = result;
212+
} else if (currFrame.dependentFrames.empty()) {
213+
// Cache standalone result.
214+
standaloneCache[element] = result;
215+
} else {
216+
// Cache dependent result.
217+
size_t highestDependentFrame = *currFrame.dependentFrames.begin();
218+
dependentCache[element] = {result, highestDependentFrame};
219+
220+
// Otherwise, the previous frame inherits the same dependent frames.
221+
prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
222+
currFrame.dependentFrames.end());
223+
224+
// Mark this current replacement as a depending replacement on the closest
225+
// dependent frame.
226+
replacementStack[highestDependentFrame].dependingReplacements.insert(
227+
element);
228+
}
229+
230+
// All depending replacements in the cache must be purged.
231+
for (InT key : currFrame.dependingReplacements)
232+
dependentCache.erase(key);
233+
234+
replacementStack.pop_back();
235+
auto it = cyclicElementFrame.find(element);
236+
it->second.pop_back();
237+
if (it->second.empty())
238+
cyclicElementFrame.erase(it);
239+
}
240+
241+
//===----------------------------------------------------------------------===//
242+
// CachedCyclicReplacer
243+
//===----------------------------------------------------------------------===//
244+
245+
/// A helper class for cases where the input/output types of the replacer
246+
/// function is identical to the types stored in the cache. This class wraps
247+
/// the user-provided replacer function, and can be used in place of the user
248+
/// function.
249+
template <typename InT, typename OutT>
250+
class CachedCyclicReplacer {
251+
public:
252+
using ReplacerFn = std::function<OutT(const InT &)>;
253+
using CycleBreakerFn =
254+
typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
255+
256+
CachedCyclicReplacer() = delete;
257+
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
258+
: replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
259+
260+
OutT operator()(const InT &element) {
261+
auto cacheEntry = cache.lookupOrInit(element);
262+
if (std::optional<OutT> result = cacheEntry.get())
263+
return *result;
264+
265+
OutT result = replacer(element);
266+
cacheEntry.resolve(result);
267+
return result;
268+
}
269+
270+
private:
271+
ReplacerFn replacer;
272+
CyclicReplacerCache<InT, OutT> cache;
273+
};
274+
275+
} // namespace mlir
276+
277+
#endif // MLIR_SUPPORT_CACHINGREPLACER_H

mlir/unittests/Support/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_unittest(MLIRSupportTests
2+
CyclicReplacerCacheTest.cpp
23
IndentedOstreamTest.cpp
34
StorageUniquerTest.cpp
45
)

0 commit comments

Comments
 (0)