Skip to content

Commit 0b5b202

Browse files
authored
[MLIR][SROA] Reuse allocators to avoid rewalking the IR (llvm#91971)
This commit extends the SROA interfaces to ensure the interface instantiations can communicate newly created allocators to the algorithm. This ensures that the SROA implementation no longer requires re-walking the IR to find new allocators.
1 parent c285297 commit 0b5b202

File tree

8 files changed

+184
-58
lines changed

8 files changed

+184
-58
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface
298298
"destructure",
299299
(ins "const ::mlir::DestructurableMemorySlot &":$slot,
300300
"const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
301-
"::mlir::OpBuilder &":$builder)
301+
"::mlir::OpBuilder &":$builder,
302+
"::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &":
303+
$newAllocators)
302304
>,
303305
InterfaceMethod<[{
304306
Hook triggered once the destructuring of a slot is complete, meaning the
305307
original slot is no longer being referred to and could be deleted.
306308
This will only be called for slots declared by this operation.
309+
310+
Must return a new destructurable allocation op if this hook creates
311+
a new destructurable op, nullopt otherwise.
307312
}],
308-
"void", "handleDestructuringComplete",
313+
"::std::optional<::mlir::DestructurableAllocationOpInterface>",
314+
"handleDestructuringComplete",
309315
(ins "const ::mlir::DestructurableMemorySlot &":$slot,
310316
"::mlir::OpBuilder &":$builder)
311317
>,

mlir/include/mlir/Transforms/SROA.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ struct SROAStatistics {
2727
llvm::Statistic *maxSubelementAmount = nullptr;
2828
};
2929

30-
/// Attempts to destructure the slots of destructurable allocators. Returns
31-
/// failure if no slot was destructured.
30+
/// Attempts to destructure the slots of destructurable allocators. Iteratively
31+
/// retries the destructuring of all slots as destructuring one slot might
32+
/// enable subsequent destructuring. Returns failure if no slot was
33+
/// destructured.
3234
LogicalResult tryToDestructureMemorySlots(
3335
ArrayRef<DestructurableAllocationOpInterface> allocators,
3436
OpBuilder &builder, const DataLayout &dataLayout,

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
7777
*destructuredType}};
7878
}
7979

80-
DenseMap<Attribute, MemorySlot>
81-
LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
82-
const SmallPtrSetImpl<Attribute> &usedIndices,
83-
OpBuilder &builder) {
80+
DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
81+
const DestructurableMemorySlot &slot,
82+
const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
83+
SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
8484
assert(slot.ptr == getResult());
8585
builder.setInsertionPointAfter(*this);
8686

@@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
9292
auto subAlloca = builder.create<LLVM::AllocaOp>(
9393
getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
9494
getArraySize());
95+
newAllocators.push_back(subAlloca);
9596
slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
9697
}
9798

9899
return slotMap;
99100
}
100101

101-
void LLVM::AllocaOp::handleDestructuringComplete(
102+
std::optional<DestructurableAllocationOpInterface>
103+
LLVM::AllocaOp::handleDestructuringComplete(
102104
const DestructurableMemorySlot &slot, OpBuilder &builder) {
103105
assert(slot.ptr == getResult());
104106
this->erase();
107+
return std::nullopt;
105108
}
106109

107110
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() {
126126
DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
127127
}
128128

129-
DenseMap<Attribute, MemorySlot>
130-
memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
131-
const SmallPtrSetImpl<Attribute> &usedIndices,
132-
OpBuilder &builder) {
129+
DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
130+
const DestructurableMemorySlot &slot,
131+
const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
132+
SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
133133
builder.setInsertionPointAfter(*this);
134134

135135
DenseMap<Attribute, MemorySlot> slotMap;
@@ -139,17 +139,20 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
139139
Type elemType = memrefType.getTypeAtIndex(usedIndex);
140140
MemRefType elemPtr = MemRefType::get({}, elemType);
141141
auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
142+
newAllocators.push_back(subAlloca);
142143
slotMap.try_emplace<MemorySlot>(usedIndex,
143144
{subAlloca.getResult(), elemType});
144145
}
145146

146147
return slotMap;
147148
}
148149

149-
void memref::AllocaOp::handleDestructuringComplete(
150+
std::optional<DestructurableAllocationOpInterface>
151+
memref::AllocaOp::handleDestructuringComplete(
150152
const DestructurableMemorySlot &slot, OpBuilder &builder) {
151153
assert(slot.ptr == getResult());
152154
this->erase();
155+
return std::nullopt;
153156
}
154157

155158
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/SROA.cpp

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
132132
/// Performs the destructuring of a destructible slot given associated
133133
/// destructuring information. The provided slot will be destructured in
134134
/// subslots as specified by its allocator.
135-
static void destructureSlot(DestructurableMemorySlot &slot,
136-
DestructurableAllocationOpInterface allocator,
137-
OpBuilder &builder, const DataLayout &dataLayout,
138-
MemorySlotDestructuringInfo &info,
139-
const SROAStatistics &statistics) {
135+
static void destructureSlot(
136+
DestructurableMemorySlot &slot,
137+
DestructurableAllocationOpInterface allocator, OpBuilder &builder,
138+
const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
139+
SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
140+
const SROAStatistics &statistics) {
140141
OpBuilder::InsertionGuard guard(builder);
141142

142143
builder.setInsertionPointToStart(slot.ptr.getParentBlock());
143144
DenseMap<Attribute, MemorySlot> subslots =
144-
allocator.destructure(slot, info.usedIndices, builder);
145+
allocator.destructure(slot, info.usedIndices, builder, newAllocators);
145146

146147
if (statistics.slotsWithMemoryBenefit &&
147148
slot.elementPtrs.size() != info.usedIndices.size())
@@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot,
185186
if (statistics.destructuredAmount)
186187
(*statistics.destructuredAmount)++;
187188

188-
allocator.handleDestructuringComplete(slot, builder);
189+
std::optional<DestructurableAllocationOpInterface> newAllocator =
190+
allocator.handleDestructuringComplete(slot, builder);
191+
// Add newly created allocators to the worklist for further processing.
192+
if (newAllocator)
193+
newAllocators.push_back(*newAllocator);
189194
}
190195

191196
LogicalResult mlir::tryToDestructureMemorySlots(
@@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots(
194199
SROAStatistics statistics) {
195200
bool destructuredAny = false;
196201

197-
for (DestructurableAllocationOpInterface allocator : allocators) {
198-
for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
199-
std::optional<MemorySlotDestructuringInfo> info =
200-
computeDestructuringInfo(slot, dataLayout);
201-
if (!info)
202-
continue;
202+
SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
203+
allocators.end());
204+
SmallVector<DestructurableAllocationOpInterface> newWorkList;
205+
newWorkList.reserve(allocators.size());
206+
// Destructuring a slot can allow for further destructuring of other
207+
// slots, destructuring is tried until no destructuring succeeds.
208+
while (true) {
209+
bool changesInThisRound = false;
210+
211+
for (DestructurableAllocationOpInterface allocator : workList) {
212+
bool destructuredAnySlot = false;
213+
for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
214+
std::optional<MemorySlotDestructuringInfo> info =
215+
computeDestructuringInfo(slot, dataLayout);
216+
if (!info)
217+
continue;
203218

204-
destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
205-
destructuredAny = true;
219+
destructureSlot(slot, allocator, builder, dataLayout, *info,
220+
newWorkList, statistics);
221+
destructuredAnySlot = true;
222+
223+
// A break is required, since destructuring a slot may invalidate the
224+
// remaining slots of an allocator.
225+
break;
226+
}
227+
if (!destructuredAnySlot)
228+
newWorkList.push_back(allocator);
229+
changesInThisRound |= destructuredAnySlot;
206230
}
231+
232+
if (!changesInThisRound)
233+
break;
234+
destructuredAny |= changesInThisRound;
235+
236+
// Swap the vector's backing memory and clear the entries in newWorkList
237+
// afterwards. This ensures that additional heap allocations can be avoided.
238+
workList.swap(newWorkList);
239+
newWorkList.clear();
207240
}
208241

209242
return success(destructuredAny);
@@ -230,23 +263,16 @@ struct SROA : public impl::SROABase<SROA> {
230263

231264
OpBuilder builder(&region.front(), region.front().begin());
232265

233-
// Destructuring a slot can allow for further destructuring of other
234-
// slots, destructuring is tried until no destructuring succeeds.
235-
while (true) {
236-
SmallVector<DestructurableAllocationOpInterface> allocators;
237-
// Build a list of allocators to attempt to destructure the slots of.
238-
// TODO: Update list on the fly to avoid repeated visiting of the same
239-
// allocators.
240-
region.walk([&](DestructurableAllocationOpInterface allocator) {
241-
allocators.emplace_back(allocator);
242-
});
243-
244-
if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
245-
statistics)))
246-
break;
266+
SmallVector<DestructurableAllocationOpInterface> allocators;
267+
// Build a list of allocators to attempt to destructure the slots of.
268+
region.walk([&](DestructurableAllocationOpInterface allocator) {
269+
allocators.emplace_back(allocator);
270+
});
247271

272+
// Attempt to destructure as many slots as possible.
273+
if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
274+
statistics)))
248275
changed = true;
249-
}
250276
}
251277
if (!changed)
252278
markAllAnalysesPreserved();

mlir/test/Transforms/sroa.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s
2+
3+
// Verifies that allocators with multiple slots are handled properly.
4+
5+
// CHECK-LABEL: func.func @multi_slot_alloca
6+
func.func @multi_slot_alloca() -> (i32, i32) {
7+
%0 = arith.constant 0 : index
8+
%1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
9+
// CHECK-COUNT-2: test.multi_slot_alloca : () -> memref<i32>
10+
%3 = memref.load %1[%0] {first}: memref<2xi32>
11+
%4 = memref.load %2[%0] {second} : memref<4xi32>
12+
return %3, %4 : i32, i32
13+
}
14+
15+
// -----
16+
17+
// Verifies that a multi slot allocator can be partially destructured.
18+
19+
func.func private @consumer(memref<2xi32>)
20+
21+
// CHECK-LABEL: func.func @multi_slot_alloca_only_second
22+
func.func @multi_slot_alloca_only_second() -> (i32, i32) {
23+
%0 = arith.constant 0 : index
24+
// CHECK: test.multi_slot_alloca : () -> memref<2xi32>
25+
// CHECK: test.multi_slot_alloca : () -> memref<i32>
26+
%1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
27+
func.call @consumer(%1) : (memref<2xi32>) -> ()
28+
%3 = memref.load %1[%0] : memref<2xi32>
29+
%4 = memref.load %2[%0] : memref<4xi32>
30+
return %3, %4 : i32, i32
31+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,35 +1199,89 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
11991199
// Not relevant for testing.
12001200
}
12011201

1202-
std::optional<PromotableAllocationOpInterface>
1203-
TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1204-
Value defaultValue,
1205-
OpBuilder &builder) {
1206-
if (defaultValue && defaultValue.use_empty())
1207-
defaultValue.getDefiningOp()->erase();
1202+
/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1203+
static std::optional<TestMultiSlotAlloca>
1204+
createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
1205+
TestMultiSlotAlloca oldOp) {
12081206

1209-
if (getNumResults() == 1) {
1210-
erase();
1207+
if (oldOp.getNumResults() == 1) {
1208+
oldOp.erase();
12111209
return std::nullopt;
12121210
}
12131211

12141212
SmallVector<Type> newTypes;
12151213
SmallVector<Value> remainingValues;
12161214

1217-
for (Value oldResult : getResults()) {
1215+
for (Value oldResult : oldOp.getResults()) {
12181216
if (oldResult == slot.ptr)
12191217
continue;
12201218
remainingValues.push_back(oldResult);
12211219
newTypes.push_back(oldResult.getType());
12221220
}
12231221

12241222
OpBuilder::InsertionGuard guard(builder);
1225-
builder.setInsertionPoint(*this);
1226-
auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
1223+
builder.setInsertionPoint(oldOp);
1224+
auto replacement =
1225+
builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
12271226
for (auto [oldResult, newResult] :
12281227
llvm::zip_equal(remainingValues, replacement.getResults()))
12291228
oldResult.replaceAllUsesWith(newResult);
12301229

1231-
erase();
1230+
oldOp.erase();
12321231
return replacement;
12331232
}
1233+
1234+
std::optional<PromotableAllocationOpInterface>
1235+
TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1236+
Value defaultValue,
1237+
OpBuilder &builder) {
1238+
if (defaultValue && defaultValue.use_empty())
1239+
defaultValue.getDefiningOp()->erase();
1240+
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1241+
}
1242+
1243+
SmallVector<DestructurableMemorySlot>
1244+
TestMultiSlotAlloca::getDestructurableSlots() {
1245+
SmallVector<DestructurableMemorySlot> slots;
1246+
for (Value result : getResults()) {
1247+
auto memrefType = cast<MemRefType>(result.getType());
1248+
auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
1249+
if (!destructurable)
1250+
continue;
1251+
1252+
std::optional<DenseMap<Attribute, Type>> destructuredType =
1253+
destructurable.getSubelementIndexMap();
1254+
if (!destructuredType)
1255+
continue;
1256+
slots.emplace_back(
1257+
DestructurableMemorySlot{{result, memrefType}, *destructuredType});
1258+
}
1259+
return slots;
1260+
}
1261+
1262+
DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
1263+
const DestructurableMemorySlot &slot,
1264+
const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1265+
SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
1266+
OpBuilder::InsertionGuard guard(builder);
1267+
builder.setInsertionPointAfter(*this);
1268+
1269+
DenseMap<Attribute, MemorySlot> slotMap;
1270+
1271+
for (Attribute usedIndex : usedIndices) {
1272+
Type elemType = slot.elementPtrs.lookup(usedIndex);
1273+
MemRefType elemPtr = MemRefType::get({}, elemType);
1274+
auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
1275+
newAllocators.push_back(subAlloca);
1276+
slotMap.try_emplace<MemorySlot>(usedIndex,
1277+
{subAlloca.getResult(0), elemType});
1278+
}
1279+
1280+
return slotMap;
1281+
}
1282+
1283+
std::optional<DestructurableAllocationOpInterface>
1284+
TestMultiSlotAlloca::handleDestructuringComplete(
1285+
const DestructurableMemorySlot &slot, OpBuilder &builder) {
1286+
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1287+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface
31693169
}
31703170

31713171
//===----------------------------------------------------------------------===//
3172-
// Test Mem2Reg
3172+
// Test Mem2Reg & SROA
31733173
//===----------------------------------------------------------------------===//
31743174

31753175
def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
3176-
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
3176+
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
3177+
DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
31773178
let results = (outs Variadic<MemRefOf<[I32]>>:$results);
31783179
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
31793180
}

0 commit comments

Comments
 (0)