Skip to content

Commit b7326df

Browse files
committed
extend with fix for multi slot allocators
1 parent 0f0628b commit b7326df

File tree

8 files changed

+119
-16
lines changed

8 files changed

+119
-16
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@ def PromotableAllocationOpInterface
6868
Hook triggered once the promotion of a slot is complete. This can
6969
also clean up the created default value if necessary.
7070
This will only be called for slots declared by this operation.
71+
72+
Must return a new promotable allocation op if this operation produced
73+
multiple promotable slots, nullopt otherwise.
7174
}],
72-
"void", "handlePromotionComplete",
75+
"std::optional<::mlir::PromotableAllocationOpInterface>",
76+
"handlePromotionComplete",
7377
(ins
7478
"const ::mlir::MemorySlot &":$slot,
7579
"::mlir::Value":$defaultValue,

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
5050
declareOp.getLocationExpr());
5151
}
5252

53-
void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
54-
Value defaultValue,
55-
OpBuilder &builder) {
53+
std::optional<PromotableAllocationOpInterface>
54+
LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
55+
Value defaultValue,
56+
OpBuilder &builder) {
5657
if (defaultValue && defaultValue.use_empty())
5758
defaultValue.getDefiningOp()->erase();
5859
this->erase();
60+
return std::nullopt;
5961
}
6062

6163
SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
9696
});
9797
}
9898

99-
void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100-
Value defaultValue,
101-
OpBuilder &builder) {
99+
std::optional<PromotableAllocationOpInterface>
100+
memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
101+
Value defaultValue,
102+
OpBuilder &builder) {
102103
if (defaultValue.use_empty())
103104
defaultValue.getDefiningOp()->erase();
104105
this->erase();
106+
return std::nullopt;
105107
}
106108

107109
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ class MemorySlotPromoter {
173173
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
174174
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
175175
/// promotion info should NOT be performed in batches.
176-
void promoteSlot();
176+
/// Returns a promotable allocation op if a new allocator was created, nullopt
177+
/// otherwise.
178+
std::optional<PromotableAllocationOpInterface> promoteSlot();
177179

178180
private:
179181
/// Computes the reaching definition for all the operations that require
@@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
595597
"after promotion, the slot pointer should not be used anymore");
596598
}
597599

598-
void MemorySlotPromoter::promoteSlot() {
600+
std::optional<PromotableAllocationOpInterface>
601+
MemorySlotPromoter::promoteSlot() {
599602
computeReachingDefInRegion(slot.ptr.getParentRegion(),
600603
getOrCreateDefaultValue());
601604

@@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
622625
if (statistics.promotedAmount)
623626
(*statistics.promotedAmount)++;
624627

625-
allocator.handlePromotionComplete(slot, defaultValue, builder);
628+
return allocator.handlePromotionComplete(slot, defaultValue, builder);
626629
}
627630

628631
LogicalResult mlir::tryToPromoteMemorySlots(
@@ -642,6 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
642645
SmallVector<PromotableAllocationOpInterface> newWorkList;
643646
newWorkList.reserve(workList.size());
644647
while (true) {
648+
bool changesInThisRound = false;
645649
for (PromotableAllocationOpInterface allocator : workList) {
646650
for (MemorySlot slot : allocator.getPromotableSlots()) {
647651
if (slot.ptr.use_empty())
@@ -650,17 +654,27 @@ LogicalResult mlir::tryToPromoteMemorySlots(
650654
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
651655
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
652656
if (info) {
653-
MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
654-
std::move(*info), statistics, blockIndexCache)
655-
.promoteSlot();
656-
promotedAny = true;
657-
continue;
657+
std::optional<PromotableAllocationOpInterface> newAllocator =
658+
MemorySlotPromoter(slot, allocator, builder, dominance,
659+
dataLayout, std::move(*info), statistics,
660+
blockIndexCache)
661+
.promoteSlot();
662+
changesInThisRound = true;
663+
// Add newly created allocators to the worklist for further
664+
// processing.
665+
if (newAllocator)
666+
newWorkList.push_back(*newAllocator);
667+
668+
// Breaking is required, as a modification to an allocator might have
669+
// removed it, making the other slots invalid.
670+
break;
658671
}
659672
newWorkList.push_back(allocator);
660673
}
661674
}
662-
if (workList.size() == newWorkList.size())
675+
if (!changesInThisRound)
663676
break;
677+
promotedAny = true;
664678

665679
// Swap the vector's backing memory and clear the entries in newWorkList
666680
// afterwards. This ensures that additional heap allocations can be avoided.

mlir/test/Transforms/mem2reg.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
2+
3+
// Verifies that allocators with mutliple slots are handled properly.
4+
5+
// CHECK-LABEL: func.func @multi_slot_alloca
6+
func.func @multi_slot_alloca() -> (i32, i32) {
7+
// CHECK-NOT: test.multi_slot_alloca
8+
%1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
9+
%3 = memref.load %1[] : memref<i32>
10+
%4 = memref.load %2[] : memref<i32>
11+
return %3, %4 : i32, i32
12+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1212
#include "mlir/IR/Verifier.h"
1313
#include "mlir/Interfaces/FunctionImplementation.h"
14+
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1415

1516
using namespace mlir;
1617
using namespace test;
@@ -1172,3 +1173,59 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
11721173
writer.writeVarInt(prop.value1);
11731174
writer.writeVarInt(prop.value2);
11741175
}
1176+
1177+
//===----------------------------------------------------------------------===//
1178+
// TestMultiSlotAlloca
1179+
//===----------------------------------------------------------------------===//
1180+
1181+
llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1182+
SmallVector<MemorySlot> slots;
1183+
for (Value result : getResults()) {
1184+
slots.push_back(MemorySlot{
1185+
result, cast<MemRefType>(result.getType()).getElementType()});
1186+
}
1187+
return slots;
1188+
}
1189+
1190+
Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1191+
OpBuilder &builder) {
1192+
return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1193+
builder.getI32IntegerAttr(42));
1194+
}
1195+
1196+
void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1197+
BlockArgument argument,
1198+
OpBuilder &builder) {
1199+
// Not relevant for testing.
1200+
}
1201+
1202+
std::optional<PromotableAllocationOpInterface>
1203+
TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1204+
Value defaultValue,
1205+
OpBuilder &builder) {
1206+
if (defaultValue && defaultValue.use_empty())
1207+
defaultValue.getDefiningOp()->erase();
1208+
1209+
if (getNumResults() == 1) {
1210+
erase();
1211+
return std::nullopt;
1212+
}
1213+
1214+
SmallVector<Type> newTypes;
1215+
SmallVector<Value> remainingValues;
1216+
1217+
for (Value oldResult : getResults()) {
1218+
if (oldResult == slot.ptr)
1219+
continue;
1220+
remainingValues.push_back(oldResult);
1221+
newTypes.push_back(oldResult.getType());
1222+
}
1223+
1224+
auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
1225+
for (auto [oldResult, newResult] :
1226+
llvm::zip_equal(remainingValues, replacement.getResults()))
1227+
oldResult.replaceAllUsesWith(newResult);
1228+
1229+
erase();
1230+
return replacement;
1231+
}

mlir/test/lib/Dialect/Test/TestOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "mlir/Interfaces/InferIntRangeInterface.h"
3737
#include "mlir/Interfaces/InferTypeOpInterface.h"
3838
#include "mlir/Interfaces/LoopLikeInterface.h"
39+
#include "mlir/Interfaces/MemorySlotInterfaces.h"
3940
#include "mlir/Interfaces/SideEffectInterfaces.h"
4041
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
4142
#include "mlir/Interfaces/ViewLikeInterface.h"

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
2828
include "mlir/Interfaces/InferIntRangeInterface.td"
2929
include "mlir/Interfaces/InferTypeOpInterface.td"
3030
include "mlir/Interfaces/LoopLikeInterface.td"
31+
include "mlir/Interfaces/MemorySlotInterfaces.td"
3132
include "mlir/Interfaces/SideEffectInterfaces.td"
3233

3334

@@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
31673168
let arguments = (ins BoolAttr:$implementsInterface);
31683169
}
31693170

3171+
//===----------------------------------------------------------------------===//
3172+
// Test Mem2Reg
3173+
//===----------------------------------------------------------------------===//
3174+
3175+
def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
3176+
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
3177+
let results = (outs Variadic<MemRefOf<[I32]>>:$results);
3178+
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
3179+
}
3180+
31703181
#endif // TEST_OPS

0 commit comments

Comments
 (0)