Skip to content

Commit f20cb3f

Browse files
authored
[mlir][bufferization] Drop the assumption for alloc result index (llvm#134503)
Relax the assumption that alloc op always has allocation at `getResult(0)`, allow to use `optimize-allocation-liveness` pass for custom ops with >1 results. Ops with multiple allocations are not handled here yet.
1 parent 44e32fb commit f20cb3f

File tree

3 files changed

+64
-3
lines changed

3 files changed

+64
-3
lines changed

mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
1818
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1919
#include "mlir/IR/Operation.h"
20+
#include "mlir/IR/Value.h"
21+
#include "mlir/Interfaces/SideEffectInterfaces.h"
2022
#include "llvm/Support/Debug.h"
23+
#include "llvm/Support/ErrorHandling.h"
2124

2225
#define DEBUG_TYPE "optimize-allocation-liveness"
2326
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -88,6 +91,19 @@ static bool hasMemoryAllocEffect(MemoryEffectOpInterface memEffectOp) {
8891
return false;
8992
}
9093

94+
/// Extracts OpResult's with Allocate effects from given op
95+
static SmallVector<OpResult>
96+
collectAllocations(MemoryEffectOpInterface allocOp) {
97+
SmallVector<MemoryEffects::EffectInstance> effects;
98+
allocOp.getEffects(effects);
99+
SmallVector<OpResult> allocResults;
100+
for (const MemoryEffects::EffectInstance &it : effects)
101+
if (isa<MemoryEffects::Allocate>(it.getEffect()))
102+
if (auto val = it.getValue(); val && val.getDefiningOp() == allocOp)
103+
allocResults.push_back(cast<OpResult>(val));
104+
return allocResults;
105+
}
106+
91107
struct OptimizeAllocationLiveness
92108
: public bufferization::impl::OptimizeAllocationLivenessPassBase<
93109
OptimizeAllocationLiveness> {
@@ -109,7 +125,15 @@ struct OptimizeAllocationLiveness
109125
auto allocOp = memEffectOp;
110126
LDBG("Checking alloc op: " << allocOp);
111127

112-
auto deallocOp = findUserWithFreeSideEffect(allocOp->getResult(0));
128+
SmallVector<OpResult> allocationResults = collectAllocations(allocOp);
129+
// Multiple allocations from a single op are not considered here yet.
130+
if (allocationResults.size() != 1)
131+
return WalkResult::advance();
132+
133+
OpResult allocResult = allocationResults[0];
134+
LDBG("On allocation result: " << allocResult);
135+
136+
auto *deallocOp = findUserWithFreeSideEffect(allocResult);
113137
if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) {
114138
// The pass handles allocations that have a single dealloc op in the
115139
// same block. We also should not hoist the dealloc op out of
@@ -119,9 +143,9 @@ struct OptimizeAllocationLiveness
119143

120144
Operation *lastUser = nullptr;
121145
const BufferViewFlowAnalysis::ValueSetT &deps =
122-
analysis.resolve(allocOp->getResult(0));
146+
analysis.resolve(allocResult);
123147
for (auto dep : llvm::make_early_inc_range(deps)) {
124-
for (auto user : dep.getUsers()) {
148+
for (auto *user : dep.getUsers()) {
125149
// We are looking for a non dealloc op user.
126150
// check if user is the dealloc op itself.
127151
if (user == deallocOp)

mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,28 @@ func.func private @test_conditional_deallocation() -> memref<32xf32, 1> {
209209
return %3 : memref<32xf32, 1>
210210
}
211211

212+
213+
// -----
214+
// CHECK-LABEL: func.func private @test_alloc_with_multiple_results() {
215+
// CHECK: %[[ID1:.+]], %[[ALLOC1:.+]] = test.alloc_with_multiple_results : index, memref<64xf32>
216+
// CHECK: memref.expand_shape %[[ALLOC1]]
217+
// CHECK: memref.dealloc %[[ALLOC1]] : memref<64xf32>
218+
// CHECK: %[[ID2:.+]], %[[ALLOC2:.+]] = test.alloc_with_multiple_results : index, memref<64xf32>
219+
// CHECK: memref.expand_shape %[[ALLOC2]]
220+
// CHECK: memref.dealloc %[[ALLOC2]] : memref<64xf32>
221+
// CHECK: return
222+
// CHECK: }
223+
224+
// This test will check that allocations with multiple results and allocated
225+
// buffer at non-zero position are accepted.
226+
func.func private @test_alloc_with_multiple_results() -> () {
227+
%id1, %alloc1 = test.alloc_with_multiple_results : index, memref<64xf32>
228+
%expand_shape1 = memref.expand_shape %alloc1 [[0, 1]] output_shape [8, 8] : memref<64xf32> into memref<8x8xf32>
229+
230+
%id2, %alloc2 = test.alloc_with_multiple_results : index, memref<64xf32>
231+
%expand_shape2 = memref.expand_shape %alloc2 [[0, 1]] output_shape [8, 8] : memref<64xf32> into memref<8x8xf32>
232+
233+
memref.dealloc %alloc1 : memref<64xf32>
234+
memref.dealloc %alloc2 : memref<64xf32>
235+
return
236+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3441,4 +3441,16 @@ def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
34413441
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
34423442
}
34433443

3444+
//===----------------------------------------------------------------------===//
3445+
// Test allocation Ops
3446+
//===----------------------------------------------------------------------===//
3447+
3448+
def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
3449+
let results = (outs Index:$index,
3450+
Res<AnyMemRef, "", [MemAlloc]>:$memref);
3451+
let assemblyFormat = [{
3452+
attr-dict `:` type($index) `,` type($memref)
3453+
}];
3454+
}
3455+
34443456
#endif // TEST_OPS

0 commit comments

Comments
 (0)