Skip to content

[MLIR][Mem2Reg] Improve performance by avoiding recomputations #91444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Transforms/Mem2Reg.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct Mem2RegStatistics {
LogicalResult
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
OpBuilder &builder, const DataLayout &dataLayout,
DominanceInfo &dominance,
Mem2RegStatistics statistics = {});

} // namespace mlir
Expand Down
65 changes: 46 additions & 19 deletions mlir/lib/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

namespace mlir {
Expand Down Expand Up @@ -158,6 +157,8 @@ class MemorySlotPromotionAnalyzer {
const DataLayout &dataLayout;
};

using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;

/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation of
/// IR.
Expand All @@ -166,7 +167,8 @@ class MemorySlotPromoter {
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
OpBuilder &builder, DominanceInfo &dominance,
const DataLayout &dataLayout, MemorySlotPromotionInfo info,
const Mem2RegStatistics &statistics);
const Mem2RegStatistics &statistics,
BlockIndexCache &blockIndexCache);

/// Actually promotes the slot by mutating IR. Promoting a slot DOES
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
Expand Down Expand Up @@ -207,16 +209,21 @@ class MemorySlotPromoter {
const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;

/// Shared cache of block indices of specific regions.
BlockIndexCache &blockIndexCache;
};

} // namespace

MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
BlockIndexCache &blockIndexCache)
: slot(slot), allocator(allocator), builder(builder), dominance(dominance),
dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
blockIndexCache(blockIndexCache) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
Expand Down Expand Up @@ -500,15 +507,29 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
}
}

/// Gets or creates a block index mapping for `region`.
static const DenseMap<Block *, size_t> &
getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
auto [it, inserted] = blockIndexCache.try_emplace(region);
if (!inserted)
return it->second;

DenseMap<Block *, size_t> &blockIndices = it->second;
SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(*region);
for (auto [index, block] : llvm::enumerate(topologicalOrder))
blockIndices[block] = index;
return blockIndices;
}

/// Sorts `ops` according to dominance. Relies on the topological order of basic
/// blocks to get a deterministic ordering.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
/// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
/// potentially expensive recomputation of a block index map.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
BlockIndexCache &blockIndexCache) {
// Produce a topological block order and construct a map to lookup the indices
// of blocks.
DenseMap<Block *, size_t> topoBlockIndices;
SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
for (auto [index, block] : llvm::enumerate(topologicalOrder))
topoBlockIndices[block] = index;
const DenseMap<Block *, size_t> &topoBlockIndices =
getOrCreateBlockIndices(blockIndexCache, &region);

// Combining the topological order of the basic blocks together with block
// internal operation order guarantees a deterministic, dominance respecting
Expand All @@ -527,7 +548,8 @@ void MemorySlotPromoter::removeBlockingUses() {
llvm::make_first_range(info.userToBlockingUses));

// Sort according to dominance.
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
blockIndexCache);

llvm::SmallVector<Operation *> toErase;
// List of all replaced values in the slot.
Expand Down Expand Up @@ -605,20 +627,25 @@ void MemorySlotPromoter::promoteSlot() {

LogicalResult mlir::tryToPromoteMemorySlots(
ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
const DataLayout &dataLayout, Mem2RegStatistics statistics) {
const DataLayout &dataLayout, DominanceInfo &dominance,
Mem2RegStatistics statistics) {
bool promotedAny = false;

// A cache that stores deterministic block indices which are used to determine
// a valid operation modification order. The block index maps are computed
// lazily and cached to avoid expensive recomputation.
BlockIndexCache blockIndexCache;

for (PromotableAllocationOpInterface allocator : allocators) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a slightly unrelated note: Would it make sense to move in the repeated promotion try loop to this place? I guess that keeping some of the caches around for longer would be beneficial. Additionally, we could avoid re-walking the entire region to get all the allocators again.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like a reasonable follow up step to me. The advantage is that the block indices could be cached longer and the possible disadvantage is that tryToPromoteMemorySlots always promotes all allocas of the region (which today we do anyways)? Are there other trade-offs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would only attempt promotion on the allocas sent in originally. Just retrying on the set of sent in ones - minus the promoted ones, seems to be the most efficient and expected behavior.

for (MemorySlot slot : allocator.getPromotableSlots()) {
if (slot.ptr.use_empty())
continue;

DominanceInfo dominance;
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
std::move(*info), statistics)
std::move(*info), statistics, blockIndexCache)
.promoteSlot();
promotedAny = true;
}
Expand All @@ -640,6 +667,10 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {

bool changed = false;

auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
auto &dominance = getAnalysis<DominanceInfo>();

for (Region &region : scopeOp->getRegions()) {
if (region.getBlocks().empty())
continue;
Expand All @@ -655,16 +686,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
allocators.emplace_back(allocator);
});

auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);

// Attempt promoting until no promotion succeeds.
if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
statistics)))
dominance, statistics)))
break;

changed = true;
getAnalysisManager().invalidate({});
}
}
if (!changed)
Expand Down