|
| 1 | +//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This file contains helper classes for caching replacer-like functions that |
| 10 | +// map values between two domains. They are able to handle replacer logic that |
| 11 | +// contains self-recursion. |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +#ifndef MLIR_SUPPORT_CACHINGREPLACER_H |
| 16 | +#define MLIR_SUPPORT_CACHINGREPLACER_H |
| 17 | + |
| 18 | +#include "mlir/IR/Visitors.h" |
| 19 | +#include "llvm/ADT/DenseSet.h" |
| 20 | +#include "llvm/ADT/MapVector.h" |
| 21 | +#include <set> |
| 22 | + |
| 23 | +namespace mlir { |
| 24 | + |
| 25 | +//===----------------------------------------------------------------------===// |
| 26 | +// CyclicReplacerCache |
| 27 | +//===----------------------------------------------------------------------===// |
| 28 | + |
| 29 | +/// A cache for replacer-like functions that map values between two domains. The |
| 30 | +/// difference compared to just using a map to cache in-out pairs is that this |
| 31 | +/// class is able to handle replacer logic that is self-recursive (and thus may |
| 32 | +/// cause infinite recursion in the naive case). |
| 33 | +/// |
| 34 | +/// This class provides a hook for the user to perform cycle pruning when a |
| 35 | +/// cycle is identified, and is able to perform context-sensitive caching so |
| 36 | +/// that the replacement result for an input that is part of a pruned cycle can |
| 37 | +/// be distinct from the replacement result for the same input when it is not |
| 38 | +/// part of a cycle. |
| 39 | +/// |
| 40 | +/// In addition, this class allows deferring cycle pruning until specific inputs |
| 41 | +/// are repeated. This is useful for cases where not all elements in a cycle can |
| 42 | +/// perform pruning. The user still must guarantee that at least one element in |
| 43 | +/// any given cycle can perform pruning. Even if not, an assertion will |
| 44 | +/// eventually be tripped instead of infinite recursion (the run-time is |
| 45 | +/// linearly bounded by the maximum cycle length of its input). |
| 46 | +template <typename InT, typename OutT> |
| 47 | +class CyclicReplacerCache { |
| 48 | +public: |
| 49 | + /// User-provided replacement function & cycle-breaking functions. |
| 50 | + /// The cycle-breaking function must not make any more recursive invocations |
| 51 | + /// to this cached replacer. |
| 52 | + using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>; |
| 53 | + |
| 54 | + CyclicReplacerCache() = delete; |
| 55 | + CyclicReplacerCache(CycleBreakerFn cycleBreaker) |
| 56 | + : cycleBreaker(std::move(cycleBreaker)) {} |
| 57 | + |
| 58 | + /// A possibly unresolved cache entry. |
| 59 | + /// If unresolved, the entry must be resolved before it goes out of scope. |
| 60 | + struct CacheEntry { |
| 61 | + public: |
| 62 | + ~CacheEntry() { assert(result && "unresovled cache entry"); } |
| 63 | + |
| 64 | + /// Check whether this node was repeated during recursive replacements. |
| 65 | + /// This only makes sense to be called after all recursive replacements are |
| 66 | + /// completed and the current element has resurfaced to the top of the |
| 67 | + /// replacement stack. |
| 68 | + bool wasRepeated() const { |
| 69 | + // If the top frame includes itself as a dependency, then it must have |
| 70 | + // been repeated. |
| 71 | + ReplacementFrame &currFrame = cache.replacementStack.back(); |
| 72 | + size_t currFrameIndex = cache.replacementStack.size() - 1; |
| 73 | + return currFrame.dependentFrames.count(currFrameIndex); |
| 74 | + } |
| 75 | + |
| 76 | + /// Resolve an unresolved cache entry by providing the result to be stored |
| 77 | + /// in the cache. |
| 78 | + void resolve(OutT result) { |
| 79 | + assert(!this->result && "cache entry already resolved"); |
| 80 | + this->result = result; |
| 81 | + cache.finalizeReplacement(element, result); |
| 82 | + } |
| 83 | + |
| 84 | + /// Get the resolved result if one exists. |
| 85 | + std::optional<OutT> get() { return result; } |
| 86 | + |
| 87 | + private: |
| 88 | + friend class CyclicReplacerCache; |
| 89 | + CacheEntry() = delete; |
| 90 | + CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element, |
| 91 | + std::optional<OutT> result = std::nullopt) |
| 92 | + : cache(cache), element(element), result(result) {} |
| 93 | + |
| 94 | + CyclicReplacerCache<InT, OutT> &cache; |
| 95 | + InT element; |
| 96 | + std::optional<OutT> result; |
| 97 | + }; |
| 98 | + |
| 99 | + /// Lookup the cache for a pre-calculated replacement for `element`. |
| 100 | + /// If one exists, a resolved CacheEntry will be returned. Otherwise, an |
| 101 | + /// unresolved CacheEntry will be returned, and the caller must resolve it |
| 102 | + /// with the calculated replacement so it can be registered in the cache for |
| 103 | + /// future use. |
| 104 | + /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved |
| 105 | + /// CacheEntries that are returned must be resolved in reverse order of |
| 106 | + /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and |
| 107 | + /// the first retrieved CacheEntry must be resolved last. This should be |
| 108 | + /// natural when used as a stack / inside recursion. |
| 109 | + CacheEntry lookupOrInit(const InT &element); |
| 110 | + |
| 111 | +private: |
| 112 | + /// Register the replacement in the cache and update the replacementStack. |
| 113 | + void finalizeReplacement(const InT &element, const OutT &result); |
| 114 | + |
| 115 | + CycleBreakerFn cycleBreaker; |
| 116 | + DenseMap<InT, OutT> standaloneCache; |
| 117 | + |
| 118 | + struct DependentReplacement { |
| 119 | + OutT replacement; |
| 120 | + /// The highest replacement frame index that this cache entry is dependent |
| 121 | + /// on. |
| 122 | + size_t highestDependentFrame; |
| 123 | + }; |
| 124 | + DenseMap<InT, DependentReplacement> dependentCache; |
| 125 | + |
| 126 | + struct ReplacementFrame { |
| 127 | + /// The set of elements that is only legal while under this current frame. |
| 128 | + /// They need to be removed from the cache when this frame is popped off the |
| 129 | + /// replacement stack. |
| 130 | + DenseSet<InT> dependingReplacements; |
| 131 | + /// The set of frame indices that this current frame's replacement is |
| 132 | + /// dependent on, ordered from highest to lowest. |
| 133 | + std::set<size_t, std::greater<size_t>> dependentFrames; |
| 134 | + }; |
| 135 | + /// Every element currently in the progress of being replaced pushes a frame |
| 136 | + /// onto this stack. |
| 137 | + SmallVector<ReplacementFrame> replacementStack; |
| 138 | + /// Maps from each input element to its indices on the replacement stack. |
| 139 | + DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame; |
| 140 | + /// If set to true, we are currently asking an element to break a cycle. No |
| 141 | + /// more recursive invocations is allowed while this is true (the replacement |
| 142 | + /// stack can no longer grow). |
| 143 | + bool resolvingCycle = false; |
| 144 | +}; |
| 145 | + |
| 146 | +template <typename InT, typename OutT> |
| 147 | +typename CyclicReplacerCache<InT, OutT>::CacheEntry |
| 148 | +CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) { |
| 149 | + assert(!resolvingCycle && |
| 150 | + "illegal recursive invocation while breaking cycle"); |
| 151 | + |
| 152 | + if (auto it = standaloneCache.find(element); it != standaloneCache.end()) |
| 153 | + return CacheEntry(*this, element, it->second); |
| 154 | + |
| 155 | + if (auto it = dependentCache.find(element); it != dependentCache.end()) { |
| 156 | + // pdate the current top frame (the element that invoked this current |
| 157 | + // replacement) to include any dependencies the cache entry had. |
| 158 | + ReplacementFrame &currFrame = replacementStack.back(); |
| 159 | + currFrame.dependentFrames.insert(it->second.highestDependentFrame); |
| 160 | + return CacheEntry(*this, element, it->second.replacement); |
| 161 | + } |
| 162 | + |
| 163 | + auto [it, inserted] = cyclicElementFrame.try_emplace(element); |
| 164 | + if (!inserted) { |
| 165 | + // This is a repeat of a known element. Try to break cycle here. |
| 166 | + resolvingCycle = true; |
| 167 | + std::optional<OutT> result = cycleBreaker(element); |
| 168 | + resolvingCycle = false; |
| 169 | + if (result) { |
| 170 | + // Cycle was broken. |
| 171 | + size_t dependentFrame = it->second.back(); |
| 172 | + dependentCache[element] = {*result, dependentFrame}; |
| 173 | + ReplacementFrame &currFrame = replacementStack.back(); |
| 174 | + // If this is a repeat, there is no replacement frame to pop. Mark the top |
| 175 | + // frame as being dependent on this element. |
| 176 | + currFrame.dependentFrames.insert(dependentFrame); |
| 177 | + |
| 178 | + return CacheEntry(*this, element, *result); |
| 179 | + } |
| 180 | + |
| 181 | + // Cycle could not be broken. |
| 182 | + // A legal setup must ensure at least one element of each cycle can break |
| 183 | + // cycles. Under this setup, each element can be seen at most twice before |
| 184 | + // the cycle is broken. If we see an element more than twice, we know this |
| 185 | + // is an illegal setup. |
| 186 | + assert(it->second.size() <= 2 && "illegal 3rd repeat of input"); |
| 187 | + } |
| 188 | + |
| 189 | + // Otherwise, either this is the first time we see this element, or this |
| 190 | + // element could not break this cycle. |
| 191 | + it->second.push_back(replacementStack.size()); |
| 192 | + replacementStack.emplace_back(); |
| 193 | + |
| 194 | + return CacheEntry(*this, element); |
| 195 | +} |
| 196 | + |
| 197 | +template <typename InT, typename OutT> |
| 198 | +void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element, |
| 199 | + const OutT &result) { |
| 200 | + ReplacementFrame &currFrame = replacementStack.back(); |
| 201 | + // With the conclusion of this replacement frame, the current element is no |
| 202 | + // longer a dependent element. |
| 203 | + currFrame.dependentFrames.erase(replacementStack.size() - 1); |
| 204 | + |
| 205 | + auto prevLayerIter = ++replacementStack.rbegin(); |
| 206 | + if (prevLayerIter == replacementStack.rend()) { |
| 207 | + // If this is the last frame, there should be zero dependents. |
| 208 | + assert(currFrame.dependentFrames.empty() && |
| 209 | + "internal error: top-level dependent replacement"); |
| 210 | + // Cache standalone result. |
| 211 | + standaloneCache[element] = result; |
| 212 | + } else if (currFrame.dependentFrames.empty()) { |
| 213 | + // Cache standalone result. |
| 214 | + standaloneCache[element] = result; |
| 215 | + } else { |
| 216 | + // Cache dependent result. |
| 217 | + size_t highestDependentFrame = *currFrame.dependentFrames.begin(); |
| 218 | + dependentCache[element] = {result, highestDependentFrame}; |
| 219 | + |
| 220 | + // Otherwise, the previous frame inherits the same dependent frames. |
| 221 | + prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(), |
| 222 | + currFrame.dependentFrames.end()); |
| 223 | + |
| 224 | + // Mark this current replacement as a depending replacement on the closest |
| 225 | + // dependent frame. |
| 226 | + replacementStack[highestDependentFrame].dependingReplacements.insert( |
| 227 | + element); |
| 228 | + } |
| 229 | + |
| 230 | + // All depending replacements in the cache must be purged. |
| 231 | + for (InT key : currFrame.dependingReplacements) |
| 232 | + dependentCache.erase(key); |
| 233 | + |
| 234 | + replacementStack.pop_back(); |
| 235 | + auto it = cyclicElementFrame.find(element); |
| 236 | + it->second.pop_back(); |
| 237 | + if (it->second.empty()) |
| 238 | + cyclicElementFrame.erase(it); |
| 239 | +} |
| 240 | + |
| 241 | +//===----------------------------------------------------------------------===// |
| 242 | +// CachedCyclicReplacer |
| 243 | +//===----------------------------------------------------------------------===// |
| 244 | + |
| 245 | +/// A helper class for cases where the input/output types of the replacer |
| 246 | +/// function is identical to the types stored in the cache. This class wraps |
| 247 | +/// the user-provided replacer function, and can be used in place of the user |
| 248 | +/// function. |
| 249 | +template <typename InT, typename OutT> |
| 250 | +class CachedCyclicReplacer { |
| 251 | +public: |
| 252 | + using ReplacerFn = std::function<OutT(const InT &)>; |
| 253 | + using CycleBreakerFn = |
| 254 | + typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn; |
| 255 | + |
| 256 | + CachedCyclicReplacer() = delete; |
| 257 | + CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker) |
| 258 | + : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {} |
| 259 | + |
| 260 | + OutT operator()(const InT &element) { |
| 261 | + auto cacheEntry = cache.lookupOrInit(element); |
| 262 | + if (std::optional<OutT> result = cacheEntry.get()) |
| 263 | + return *result; |
| 264 | + |
| 265 | + OutT result = replacer(element); |
| 266 | + cacheEntry.resolve(result); |
| 267 | + return result; |
| 268 | + } |
| 269 | + |
| 270 | +private: |
| 271 | + ReplacerFn replacer; |
| 272 | + CyclicReplacerCache<InT, OutT> cache; |
| 273 | +}; |
| 274 | + |
| 275 | +} // namespace mlir |
| 276 | + |
| 277 | +#endif // MLIR_SUPPORT_CACHINGREPLACER_H |
0 commit comments