Skip to content

Commit 778494a

Browse files
committed
[mlir][bufferization] Add bufferization.dealloc canonicalizer to remove unused alloc-dealloc pairs
Deallocation operations where the allocated value is the 'memref' and 'retained' list are currently not supported. This is because when values are in the retained list, they typically have a use-site at a later point and another deallocation op exists at that later point to free the memref then. There alrady exists a canonicalization pattern in the buffer deallocation simplification pass that removes the allocated value from the earlier dealloc because it will never be actually deallocated in that case and thus does not have to be considered in this new pattern. Differential Revision: https://reviews.llvm.org/D158740
1 parent ad7e250 commit 778494a

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,13 +963,65 @@ struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
963963
}
964964
};
965965

966+
/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
967+
/// other user of the allocated value and the allocating operation can be safely
968+
/// removed. If the same value is present multiple times, this pattern relies on
969+
/// other canonicalization patterns to remove the duplicate first.
970+
///
971+
/// Example:
972+
/// ```mlir
973+
/// %alloc = memref.alloc() : memref<2xi32>
974+
/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
975+
/// ```
976+
/// is canonicalized to
977+
/// ```mlir
978+
/// bufferization.dealloc (%arg0 : ...) if (%true)
979+
/// ```
980+
struct RemoveAllocDeallocPairWhenNoOtherUsers
981+
: public OpRewritePattern<DeallocOp> {
982+
using OpRewritePattern<DeallocOp>::OpRewritePattern;
983+
984+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
985+
PatternRewriter &rewriter) const override {
986+
SmallVector<Value> newMemrefs, newConditions;
987+
SmallVector<Operation *> toDelete;
988+
for (auto [memref, cond] :
989+
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
990+
if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
991+
// Check that it is indeed an allocate effect, that the op has no other
992+
// side effects (which would not allow us to remove the op), and that
993+
// there are no other users.
994+
if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
995+
hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
996+
memref.hasOneUse()) {
997+
toDelete.push_back(allocOp);
998+
continue;
999+
}
1000+
}
1001+
1002+
newMemrefs.push_back(memref);
1003+
newConditions.push_back(cond);
1004+
}
1005+
1006+
if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1007+
rewriter)))
1008+
return failure();
1009+
1010+
for (Operation *op : toDelete)
1011+
rewriter.eraseOp(op);
1012+
1013+
return success();
1014+
}
1015+
};
1016+
9661017
} // anonymous namespace
9671018

9681019
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
9691020
MLIRContext *context) {
9701021
results.add<DeallocRemoveDuplicateDeallocMemrefs,
9711022
DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
972-
EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
1023+
EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1024+
RemoveAllocDeallocPairWhenNoOtherUsers>(context);
9731025
}
9741026

9751027
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,12 +323,12 @@ func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2x
323323

324324
// -----
325325

326-
func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) {
326+
func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) -> memref<2xi32> {
327327
%alloc = memref.alloc() : memref<2xi32>
328328
%base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
329329
%base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
330330
bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
331-
return
331+
return %alloc : memref<2xi32>
332332
}
333333

334334
// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
@@ -337,3 +337,17 @@ func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1,
337337
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
338338
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
339339
// CHECK-NEXT: return
340+
341+
// -----
342+
343+
func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>) {
344+
%true = arith.constant true
345+
%alloc = memref.alloc() : memref<2xi32>
346+
bufferization.dealloc (%alloc, %arg0 : memref<2xi32>, memref<2xi32>) if (%true, %true)
347+
return
348+
}
349+
350+
// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
351+
// CHECK-SAME:([[ARG0:%.+]]: memref<2xi32>)
352+
// CHECK-NOT: memref.alloc(
353+
// CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true

0 commit comments

Comments
 (0)