Skip to content

Commit eeafc9d

Browse files
authored
[MLIR][Mem2Reg] Fix multi slot handling & move retry handling (#91464)
This commit fixes Mem2Reg's multi-slot allocator handling and extends the test dialect to test this. Additionally, this modifies Mem2Reg's API to always attempt a full promotion on all the passed-in "allocators". This ensures that the pass does not require unnecessary walks over the regions and improves caching benefits.
1 parent 2163ae7 commit eeafc9d

File tree

9 files changed

+173
-39
lines changed

9 files changed

+173
-39
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@ def PromotableAllocationOpInterface
6868
Hook triggered once the promotion of a slot is complete. This can
6969
also clean up the created default value if necessary.
7070
This will only be called for slots declared by this operation.
71+
72+
Must return a new promotable allocation op if this operation produced
73+
multiple promotable slots, nullopt otherwise.
7174
}],
72-
"void", "handlePromotionComplete",
75+
"::std::optional<::mlir::PromotableAllocationOpInterface>",
76+
"handlePromotionComplete",
7377
(ins
7478
"const ::mlir::MemorySlot &":$slot,
7579
"::mlir::Value":$defaultValue,

mlir/include/mlir/Transforms/Mem2Reg.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#ifndef MLIR_TRANSFORMS_MEM2REG_H
1010
#define MLIR_TRANSFORMS_MEM2REG_H
1111

12-
#include "mlir/IR/PatternMatch.h"
1312
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1413
#include "llvm/ADT/Statistic.h"
1514

@@ -23,8 +22,9 @@ struct Mem2RegStatistics {
2322
llvm::Statistic *newBlockArgumentAmount = nullptr;
2423
};
2524

26-
/// Attempts to promote the memory slots of the provided allocators. Succeeds if
27-
/// at least one memory slot was promoted.
25+
/// Attempts to promote the memory slots of the provided allocators. Iteratively
26+
/// retries the promotion of all slots as promoting one slot might enable
27+
/// subsequent promotions. Succeeds if at least one memory slot was promoted.
2828
LogicalResult
2929
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
3030
OpBuilder &builder, const DataLayout &dataLayout,

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
5050
declareOp.getLocationExpr());
5151
}
5252

53-
void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
54-
Value defaultValue,
55-
OpBuilder &builder) {
53+
std::optional<PromotableAllocationOpInterface>
54+
LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
55+
Value defaultValue,
56+
OpBuilder &builder) {
5657
if (defaultValue && defaultValue.use_empty())
5758
defaultValue.getDefiningOp()->erase();
5859
this->erase();
60+
return std::nullopt;
5961
}
6062

6163
SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
9696
});
9797
}
9898

99-
void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100-
Value defaultValue,
101-
OpBuilder &builder) {
99+
std::optional<PromotableAllocationOpInterface>
100+
memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
101+
Value defaultValue,
102+
OpBuilder &builder) {
102103
if (defaultValue.use_empty())
103104
defaultValue.getDefiningOp()->erase();
104105
this->erase();
106+
return std::nullopt;
105107
}
106108

107109
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ class MemorySlotPromoter {
173173
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
174174
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
175175
/// promotion info should NOT be performed in batches.
176-
void promoteSlot();
176+
/// Returns a promotable allocation op if a new allocator was created, nullopt
177+
/// otherwise.
178+
std::optional<PromotableAllocationOpInterface> promoteSlot();
177179

178180
private:
179181
/// Computes the reaching definition for all the operations that require
@@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
595597
"after promotion, the slot pointer should not be used anymore");
596598
}
597599

598-
void MemorySlotPromoter::promoteSlot() {
600+
std::optional<PromotableAllocationOpInterface>
601+
MemorySlotPromoter::promoteSlot() {
599602
computeReachingDefInRegion(slot.ptr.getParentRegion(),
600603
getOrCreateDefaultValue());
601604

@@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
622625
if (statistics.promotedAmount)
623626
(*statistics.promotedAmount)++;
624627

625-
allocator.handlePromotionComplete(slot, defaultValue, builder);
628+
return allocator.handlePromotionComplete(slot, defaultValue, builder);
626629
}
627630

628631
LogicalResult mlir::tryToPromoteMemorySlots(
@@ -636,20 +639,50 @@ LogicalResult mlir::tryToPromoteMemorySlots(
636639
// lazily and cached to avoid expensive recomputation.
637640
BlockIndexCache blockIndexCache;
638641

639-
for (PromotableAllocationOpInterface allocator : allocators) {
640-
for (MemorySlot slot : allocator.getPromotableSlots()) {
641-
if (slot.ptr.use_empty())
642-
continue;
643-
644-
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
645-
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
646-
if (info) {
647-
MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
648-
std::move(*info), statistics, blockIndexCache)
649-
.promoteSlot();
650-
promotedAny = true;
642+
SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
643+
allocators.end());
644+
645+
SmallVector<PromotableAllocationOpInterface> newWorkList;
646+
newWorkList.reserve(workList.size());
647+
while (true) {
648+
bool changesInThisRound = false;
649+
for (PromotableAllocationOpInterface allocator : workList) {
650+
bool changedAllocator = false;
651+
for (MemorySlot slot : allocator.getPromotableSlots()) {
652+
if (slot.ptr.use_empty())
653+
continue;
654+
655+
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
656+
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
657+
if (info) {
658+
std::optional<PromotableAllocationOpInterface> newAllocator =
659+
MemorySlotPromoter(slot, allocator, builder, dominance,
660+
dataLayout, std::move(*info), statistics,
661+
blockIndexCache)
662+
.promoteSlot();
663+
changedAllocator = true;
664+
// Add newly created allocators to the worklist for further
665+
// processing.
666+
if (newAllocator)
667+
newWorkList.push_back(*newAllocator);
668+
669+
// A break is required, since promoting a slot may invalidate the
670+
// remaining slots of an allocator.
671+
break;
672+
}
651673
}
674+
if (!changedAllocator)
675+
newWorkList.push_back(allocator);
676+
changesInThisRound |= changedAllocator;
652677
}
678+
if (!changesInThisRound)
679+
break;
680+
promotedAny = true;
681+
682+
// Swap the vector's backing memory and clear the entries in newWorkList
683+
// afterwards. This ensures that additional heap allocations can be avoided.
684+
workList.swap(newWorkList);
685+
newWorkList.clear();
653686
}
654687

655688
return success(promotedAny);
@@ -677,22 +710,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
677710

678711
OpBuilder builder(&region.front(), region.front().begin());
679712

680-
// Promoting a slot can allow for further promotion of other slots,
681-
// promotion is tried until no promotion succeeds.
682-
while (true) {
683-
SmallVector<PromotableAllocationOpInterface> allocators;
684-
// Build a list of allocators to attempt to promote the slots of.
685-
region.walk([&](PromotableAllocationOpInterface allocator) {
686-
allocators.emplace_back(allocator);
687-
});
688-
689-
// Attempt promoting until no promotion succeeds.
690-
if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
691-
dominance, statistics)))
692-
break;
713+
SmallVector<PromotableAllocationOpInterface> allocators;
714+
// Build a list of allocators to attempt to promote the slots of.
715+
region.walk([&](PromotableAllocationOpInterface allocator) {
716+
allocators.emplace_back(allocator);
717+
});
693718

719+
// Attempt promoting as many of the slots as possible.
720+
if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
721+
dominance, statistics)))
694722
changed = true;
695-
}
696723
}
697724
if (!changed)
698725
markAllAnalysesPreserved();

mlir/test/Transforms/mem2reg.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
2+
3+
// Verifies that allocators with multiple slots are handled properly.
4+
5+
// CHECK-LABEL: func.func @multi_slot_alloca
6+
func.func @multi_slot_alloca() -> (i32, i32) {
7+
// CHECK-NOT: test.multi_slot_alloca
8+
%1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
9+
%3 = memref.load %1[] : memref<i32>
10+
%4 = memref.load %2[] : memref<i32>
11+
return %3, %4 : i32, i32
12+
}
13+
14+
// -----
15+
16+
// Verifies that a multi slot allocator can be partially promoted.
17+
18+
func.func private @consumer(memref<i32>)
19+
20+
// CHECK-LABEL: func.func @multi_slot_alloca_only_second
21+
func.func @multi_slot_alloca_only_second() -> (i32, i32) {
22+
// CHECK: %{{[[:alnum:]]+}} = test.multi_slot_alloca
23+
%1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
24+
func.call @consumer(%1) : (memref<i32>) -> ()
25+
%3 = memref.load %1[] : memref<i32>
26+
%4 = memref.load %2[] : memref<i32>
27+
return %3, %4 : i32, i32
28+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1212
#include "mlir/IR/Verifier.h"
1313
#include "mlir/Interfaces/FunctionImplementation.h"
14+
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1415

1516
using namespace mlir;
1617
using namespace test;
@@ -1172,3 +1173,61 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
11721173
writer.writeVarInt(prop.value1);
11731174
writer.writeVarInt(prop.value2);
11741175
}
1176+
1177+
//===----------------------------------------------------------------------===//
1178+
// TestMultiSlotAlloca
1179+
//===----------------------------------------------------------------------===//
1180+
1181+
llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1182+
SmallVector<MemorySlot> slots;
1183+
for (Value result : getResults()) {
1184+
slots.push_back(MemorySlot{
1185+
result, cast<MemRefType>(result.getType()).getElementType()});
1186+
}
1187+
return slots;
1188+
}
1189+
1190+
Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1191+
OpBuilder &builder) {
1192+
return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1193+
builder.getI32IntegerAttr(42));
1194+
}
1195+
1196+
void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1197+
BlockArgument argument,
1198+
OpBuilder &builder) {
1199+
// Not relevant for testing.
1200+
}
1201+
1202+
std::optional<PromotableAllocationOpInterface>
1203+
TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1204+
Value defaultValue,
1205+
OpBuilder &builder) {
1206+
if (defaultValue && defaultValue.use_empty())
1207+
defaultValue.getDefiningOp()->erase();
1208+
1209+
if (getNumResults() == 1) {
1210+
erase();
1211+
return std::nullopt;
1212+
}
1213+
1214+
SmallVector<Type> newTypes;
1215+
SmallVector<Value> remainingValues;
1216+
1217+
for (Value oldResult : getResults()) {
1218+
if (oldResult == slot.ptr)
1219+
continue;
1220+
remainingValues.push_back(oldResult);
1221+
newTypes.push_back(oldResult.getType());
1222+
}
1223+
1224+
OpBuilder::InsertionGuard guard(builder);
1225+
builder.setInsertionPoint(*this);
1226+
auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
1227+
for (auto [oldResult, newResult] :
1228+
llvm::zip_equal(remainingValues, replacement.getResults()))
1229+
oldResult.replaceAllUsesWith(newResult);
1230+
1231+
erase();
1232+
return replacement;
1233+
}

mlir/test/lib/Dialect/Test/TestOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "mlir/Interfaces/InferIntRangeInterface.h"
3737
#include "mlir/Interfaces/InferTypeOpInterface.h"
3838
#include "mlir/Interfaces/LoopLikeInterface.h"
39+
#include "mlir/Interfaces/MemorySlotInterfaces.h"
3940
#include "mlir/Interfaces/SideEffectInterfaces.h"
4041
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
4142
#include "mlir/Interfaces/ViewLikeInterface.h"

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
2828
include "mlir/Interfaces/InferIntRangeInterface.td"
2929
include "mlir/Interfaces/InferTypeOpInterface.td"
3030
include "mlir/Interfaces/LoopLikeInterface.td"
31+
include "mlir/Interfaces/MemorySlotInterfaces.td"
3132
include "mlir/Interfaces/SideEffectInterfaces.td"
3233

3334

@@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
31673168
let arguments = (ins BoolAttr:$implementsInterface);
31683169
}
31693170

3171+
//===----------------------------------------------------------------------===//
3172+
// Test Mem2Reg
3173+
//===----------------------------------------------------------------------===//
3174+
3175+
def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
3176+
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
3177+
let results = (outs Variadic<MemRefOf<[I32]>>:$results);
3178+
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
3179+
}
3180+
31703181
#endif // TEST_OPS

0 commit comments

Comments
 (0)