Skip to content

Commit fff1830

Browse files
committed
[mlir][bufferization] Run the simple dealloc canonicalization patterns as part of BufferDeallocationSimplification
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158744
1 parent 778494a commit fff1830

File tree

4 files changed

+19
-12
lines changed

4 files changed

+19
-12
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
5858
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
5959
ToMemrefOp toMemref);
6060

61+
/// Add the canonicalization patterns for bufferization.dealloc to the given
62+
/// pattern set to make them available to other passes (such as
63+
/// BufferDeallocationSimplification).
64+
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns,
65+
MLIRContext *context);
66+
6167
} // namespace bufferization
6268
} // namespace mlir
6369

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,10 +1018,15 @@ struct RemoveAllocDeallocPairWhenNoOtherUsers
10181018

10191019
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
10201020
MLIRContext *context) {
1021-
results.add<DeallocRemoveDuplicateDeallocMemrefs,
1022-
DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1023-
EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1024-
RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1021+
populateDeallocOpCanonicalizationPatterns(results, context);
1022+
}
1023+
1024+
void bufferization::populateDeallocOpCanonicalizationPatterns(
1025+
RewritePatternSet &patterns, MLIRContext *context) {
1026+
patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1027+
DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1028+
EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1029+
RemoveAllocDeallocPairWhenNoOtherUsers>(context);
10251030
}
10261031

10271032
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ struct BufferDeallocationSimplificationPass
432432
SplitDeallocWhenNotAliasingAnyOther,
433433
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
434434
aliasAnalysis);
435+
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
435436

436437
if (failed(
437438
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))

mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@ func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg
1515
// CHECK-LABEL: func @dealloc_deallocated_in_retained
1616
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
1717
// CHECK-NEXT: arith.constant false
18-
// CHECK-NEXT: bufferization.dealloc
1918
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
2019
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
2120
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
2221
// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
2322
// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
2423
// COM: due to the pattern under test (and thus there is no memref the retain values
2524
// COM: could alias to)
26-
// CHECK-NEXT: bufferization.dealloc
2725
// CHECK-NOT: if
2826
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
2927
// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
@@ -50,27 +48,25 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
5048
// CHECK-NEXT: arith.constant false
5149
// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
5250
// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
53-
// CHECK-NEXT: bufferization.dealloc
5451
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
5552
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
5653
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
5754
// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
5855
// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
5956
// COM: due to the pattern under test (and thus there is no memref the retain values
6057
// COM: could alias to)
61-
// CHECK-NEXT: bufferization.dealloc
6258
// CHECK-NOT: if
6359
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
6460
// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
6561
// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :
6662

6763
// -----
6864

69-
func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
65+
func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1, memref<2xi32>) {
7066
%alloc = memref.alloc() : memref<2xi32>
7167
%alloc0 = memref.alloc() : memref<2xi32>
7268
%0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
73-
return %0#0, %0#1 : i1, i1
69+
return %0#0, %0#1, %alloc : i1, i1, memref<2xi32>
7470
}
7571

7672
// CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias
@@ -79,7 +75,7 @@ func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memr
7975
// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
8076
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
8177
// CHECK-NOT: retain
82-
// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
78+
// CHECK-NEXT: return [[FALSE]], [[FALSE]], [[ALLOC]] :
8379

8480
// -----
8581

@@ -104,7 +100,6 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
104100
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
105101
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
106102
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
107-
// CHECK-NEXT: bufferization.dealloc
108103
// CHECK-NEXT: return [[V2]]#0, [[V3]] :
109104

110105
// -----

0 commit comments

Comments
 (0)