18
18
#include " mlir/Transforms/Passes.h"
19
19
#include " mlir/Transforms/RegionUtils.h"
20
20
#include " llvm/ADT/STLExtras.h"
21
- #include " llvm/Support/Casting.h"
22
21
#include " llvm/Support/GenericIteratedDominanceFrontier.h"
23
22
24
23
namespace mlir {
@@ -158,6 +157,8 @@ class MemorySlotPromotionAnalyzer {
158
157
const DataLayout &dataLayout;
159
158
};
160
159
160
+ using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t >>;
161
+
161
162
// / The MemorySlotPromoter handles the state of promoting a memory slot. It
162
163
// / wraps a slot and its associated allocator. This will perform the mutation of
163
164
// / IR.
@@ -166,7 +167,8 @@ class MemorySlotPromoter {
166
167
MemorySlotPromoter (MemorySlot slot, PromotableAllocationOpInterface allocator,
167
168
OpBuilder &builder, DominanceInfo &dominance,
168
169
const DataLayout &dataLayout, MemorySlotPromotionInfo info,
169
- const Mem2RegStatistics &statistics);
170
+ const Mem2RegStatistics &statistics,
171
+ BlockIndexCache &blockIndexCache);
170
172
171
173
// / Actually promotes the slot by mutating IR. Promoting a slot DOES
172
174
// / invalidate the MemorySlotPromotionInfo of other slots. Preparation of
@@ -207,16 +209,21 @@ class MemorySlotPromoter {
207
209
const DataLayout &dataLayout;
208
210
MemorySlotPromotionInfo info;
209
211
const Mem2RegStatistics &statistics;
212
+
213
+ // / Shared cache of block indices of specific regions.
214
+ BlockIndexCache &blockIndexCache;
210
215
};
211
216
212
217
} // namespace
213
218
214
219
MemorySlotPromoter::MemorySlotPromoter (
215
220
MemorySlot slot, PromotableAllocationOpInterface allocator,
216
221
OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
217
- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
222
+ MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
223
+ BlockIndexCache &blockIndexCache)
218
224
: slot(slot), allocator(allocator), builder(builder), dominance(dominance),
219
- dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
225
+ dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
226
+ blockIndexCache(blockIndexCache) {
220
227
#ifndef NDEBUG
221
228
auto isResultOrNewBlockArgument = [&]() {
222
229
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr ))
@@ -500,15 +507,29 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
500
507
}
501
508
}
502
509
510
+ // / Gets or creates a block index mapping for `region`.
511
+ static const DenseMap<Block *, size_t > &
512
+ getOrCreateBlockIndices (BlockIndexCache &blockIndexCache, Region *region) {
513
+ auto [it, created] = blockIndexCache.try_emplace (region);
514
+ if (!created)
515
+ return it->second ;
516
+
517
+ DenseMap<Block *, size_t > &blockIndices = it->second ;
518
+ SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (*region);
519
+ for (auto [index, block] : llvm::enumerate (topologicalOrder))
520
+ blockIndices[block] = index;
521
+ return blockIndices;
522
+ }
523
+
503
524
// / Sorts `ops` according to dominance. Relies on the topological order of basic
504
- // / blocks to get a deterministic ordering.
505
- static void dominanceSort (SmallVector<Operation *> &ops, Region &region) {
525
+ // / blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
526
+ // / potentially expensive recomputation of a block index map.
527
+ static void dominanceSort (SmallVector<Operation *> &ops, Region &region,
528
+ BlockIndexCache &blockIndexCache) {
506
529
// Produce a topological block order and construct a map to lookup the indices
507
530
// of blocks.
508
- DenseMap<Block *, size_t > topoBlockIndices;
509
- SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (region);
510
- for (auto [index, block] : llvm::enumerate (topologicalOrder))
511
- topoBlockIndices[block] = index;
531
+ const DenseMap<Block *, size_t > &topoBlockIndices =
532
+ getOrCreateBlockIndices (blockIndexCache, &region);
512
533
513
534
// Combining the topological order of the basic blocks together with block
514
535
// internal operation order guarantees a deterministic, dominance respecting
@@ -527,7 +548,8 @@ void MemorySlotPromoter::removeBlockingUses() {
527
548
llvm::make_first_range (info.userToBlockingUses ));
528
549
529
550
// Sort according to dominance.
530
- dominanceSort (usersToRemoveUses, *slot.ptr .getParentBlock ()->getParent ());
551
+ dominanceSort (usersToRemoveUses, *slot.ptr .getParentBlock ()->getParent (),
552
+ blockIndexCache);
531
553
532
554
llvm::SmallVector<Operation *> toErase;
533
555
// List of all replaced values in the slot.
@@ -605,20 +627,24 @@ void MemorySlotPromoter::promoteSlot() {
605
627
606
628
LogicalResult mlir::tryToPromoteMemorySlots (
607
629
ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
608
- const DataLayout &dataLayout, Mem2RegStatistics statistics) {
630
+ const DataLayout &dataLayout, DominanceInfo &dominance,
631
+ Mem2RegStatistics statistics) {
609
632
bool promotedAny = false ;
610
633
634
+ // Cache for block index maps. This is required to avoid expensive
635
+ // recomputations.
636
+ BlockIndexCache blockIndexCache;
637
+
611
638
for (PromotableAllocationOpInterface allocator : allocators) {
612
639
for (MemorySlot slot : allocator.getPromotableSlots ()) {
613
640
if (slot.ptr .use_empty ())
614
641
continue ;
615
642
616
- DominanceInfo dominance;
617
643
MemorySlotPromotionAnalyzer analyzer (slot, dominance, dataLayout);
618
644
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo ();
619
645
if (info) {
620
646
MemorySlotPromoter (slot, allocator, builder, dominance, dataLayout,
621
- std::move (*info), statistics)
647
+ std::move (*info), statistics, blockIndexCache )
622
648
.promoteSlot ();
623
649
promotedAny = true ;
624
650
}
@@ -640,6 +666,10 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
640
666
641
667
bool changed = false ;
642
668
669
+ auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
670
+ const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove (scopeOp);
671
+ auto &dominance = getAnalysis<DominanceInfo>();
672
+
643
673
for (Region &region : scopeOp->getRegions ()) {
644
674
if (region.getBlocks ().empty ())
645
675
continue ;
@@ -655,16 +685,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
655
685
allocators.emplace_back (allocator);
656
686
});
657
687
658
- auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
659
- const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove (scopeOp);
660
-
661
688
// Attempt promoting until no promotion succeeds.
662
689
if (failed (tryToPromoteMemorySlots (allocators, builder, dataLayout,
663
- statistics)))
690
+ dominance, statistics)))
664
691
break ;
665
692
666
693
changed = true ;
667
- getAnalysisManager ().invalidate ({});
668
694
}
669
695
}
670
696
if (!changed)
0 commit comments