Skip to content

Commit 87f2dee

Browse files
committed
[mlir][bufferization] Add DeallocationSimplification pass
Adds a pass that can be run after buffer deallocation to simplify the deallocation operations. In particular, there are patterns that need alias information and thus cannot be added as regular canonicalization patterns. This initial commit moves an incorrect canonicalization pattern over to this new pass and fixes it by querying the alias analysis for the additional information it needs to be correct (there must not be any potentially aliasing memref in the retain list other than the currently matched one). It also improves this pattern by considering the `extract_strided_metadata` operation, which is inserted by the deallocation pass by default. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157398
1 parent 5a3753f commit 87f2dee

File tree

7 files changed

+260
-62
lines changed

7 files changed

+260
-62
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ struct OneShotBufferizationOptions;
2424
/// buffers.
2525
std::unique_ptr<Pass> createBufferDeallocationPass();
2626

27+
/// Creates a pass that optimizes `bufferization.dealloc` operations. For
28+
/// example, it reduces the number of alias checks needed at runtime using
29+
/// static alias analysis.
30+
std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
31+
2732
/// Run buffer deallocation.
2833
LogicalResult deallocateBuffers(Operation *op);
2934

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,26 @@ def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> {
8888
let constructor = "mlir::bufferization::createBufferDeallocationPass()";
8989
}
9090

91+
def BufferDeallocationSimplification :
92+
Pass<"buffer-deallocation-simplification", "func::FuncOp"> {
93+
let summary = "Optimizes `bufferization.dealloc` operation for more "
94+
"efficient codegen";
95+
let description = [{
96+
This pass uses static alias analysis to reduce the number of alias checks
97+
required at runtime. Such checks are sometimes necessary to make sure that
98+
memrefs aren't deallocated before their last usage (use after free) or that
99+
some memref isn't deallocated twice (double free).
100+
}];
101+
102+
let constructor =
103+
"mlir::bufferization::createBufferDeallocationSimplificationPass()";
104+
105+
let dependentDialects = [
106+
"mlir::bufferization::BufferizationDialect", "mlir::arith::ArithDialect",
107+
"mlir::memref::MemRefDialect"
108+
];
109+
}
110+
91111
def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> {
92112
let summary = "Optimizes placement of allocation operations by moving them "
93113
"into common dominators and out of nested regions";

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -869,57 +869,6 @@ struct DeallocRemoveDuplicateRetainedMemrefs
869869
}
870870
};
871871

872-
/// Remove memrefs to be deallocated that are also present in the retained list
873-
/// since they will always alias and thus never actually be deallocated.
874-
/// Example:
875-
/// ```mlir
876-
/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...)
877-
/// ```
878-
/// is canonicalized to
879-
/// ```mlir
880-
/// %0 = bufferization.dealloc retain (%arg0 : ...)
881-
/// ```
882-
struct DeallocRemoveDeallocMemrefsContainedInRetained
883-
: public OpRewritePattern<DeallocOp> {
884-
using OpRewritePattern<DeallocOp>::OpRewritePattern;
885-
886-
LogicalResult matchAndRewrite(DeallocOp deallocOp,
887-
PatternRewriter &rewriter) const override {
888-
// Unique memrefs to be deallocated.
889-
DenseMap<Value, unsigned> retained;
890-
for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
891-
retained[ret] = i;
892-
893-
// There must not be any duplicates in the retain list anymore because we
894-
// would miss updating one of the result values otherwise.
895-
if (retained.size() != deallocOp.getRetained().size())
896-
return failure();
897-
898-
SmallVector<Value> newMemrefs, newConditions;
899-
for (auto [memref, cond] :
900-
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
901-
if (retained.contains(memref)) {
902-
rewriter.setInsertionPointAfter(deallocOp);
903-
auto orOp = rewriter.create<arith::OrIOp>(
904-
deallocOp.getLoc(),
905-
deallocOp.getUpdatedConditions()[retained[memref]], cond);
906-
rewriter.replaceAllUsesExcept(
907-
deallocOp.getUpdatedConditions()[retained[memref]],
908-
orOp.getResult(), orOp);
909-
continue;
910-
}
911-
912-
newMemrefs.push_back(memref);
913-
newConditions.push_back(cond);
914-
}
915-
916-
// Return failure if we don't change anything such that we don't run into an
917-
// infinite loop of pattern applications.
918-
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
919-
rewriter);
920-
}
921-
};
922-
923872
/// Erase deallocation operations where the variadic list of memrefs to
924873
/// deallocate is empty. Example:
925874
/// ```mlir
@@ -1021,8 +970,7 @@ struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1021970
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1022971
MLIRContext *context) {
1023972
results.add<DeallocRemoveDuplicateDeallocMemrefs,
1024-
DeallocRemoveDuplicateRetainedMemrefs,
1025-
DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
973+
DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1026974
EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
1027975
}
1028976

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
//===- BufferDeallocationSimplification.cpp -------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements logic for optimizing `bufferization.dealloc` operations
10+
// that requires more analysis than what can be supported by regular
11+
// canonicalization patterns.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Analysis/AliasAnalysis.h"
16+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
17+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
22+
namespace mlir {
23+
namespace bufferization {
24+
#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
25+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
26+
} // namespace bufferization
27+
} // namespace mlir
28+
29+
using namespace mlir;
30+
using namespace mlir::bufferization;
31+
32+
//===----------------------------------------------------------------------===//
33+
// Helpers
34+
//===----------------------------------------------------------------------===//
35+
36+
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
37+
ValueRange memrefs,
38+
ValueRange conditions,
39+
PatternRewriter &rewriter) {
40+
if (deallocOp.getMemrefs() == memrefs &&
41+
deallocOp.getConditions() == conditions)
42+
return failure();
43+
44+
rewriter.updateRootInPlace(deallocOp, [&]() {
45+
deallocOp.getMemrefsMutable().assign(memrefs);
46+
deallocOp.getConditionsMutable().assign(conditions);
47+
});
48+
return success();
49+
}
50+
51+
//===----------------------------------------------------------------------===//
52+
// Patterns
53+
//===----------------------------------------------------------------------===//
54+
55+
namespace {
56+
57+
/// Remove values from the `memref` operand list that are also present in the
58+
/// `retained` list since they will always alias and thus never actually be
59+
/// deallocated. However, we also need to be certain that no other value in the
60+
/// `retained` list can alias, for which we use a static alias analysis. This is
61+
/// necessary because the `dealloc` operation is defined to return one `i1`
62+
/// value per memref in the `retained` list which represents the disjunction of
63+
/// the condition values corresponding to all aliasing values in the `memref`
64+
/// list. In particular, this means that if there is some value R in the
65+
/// `retained` list which aliases with a value M in the `memref` list (but can
66+
/// only be statically determined to may-alias) and M is also present in the
67+
/// `retained` list, then it would be illegal to remove M because the result
68+
/// corresponding to R would be computed incorrectly afterwards.
69+
/// Because we require an alias analysis, this pattern cannot be applied as a
70+
/// regular canonicalization pattern.
71+
///
72+
/// Example:
73+
/// ```mlir
74+
/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
75+
/// retain (%m0, %r0, %r1 : ...)
76+
/// ```
77+
/// is simplified to
78+
/// ```mlir
79+
/// // bufferization.dealloc without memrefs and conditions returns %false for
80+
/// // every retained value
81+
/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
82+
/// %1 = arith.ori %0#0, %cond0 : i1
83+
/// // replace %0#0 with %1
84+
/// ```
85+
/// given that `%r0` and `%r1` are statically known not to alias with `%m0`.
86+
struct DeallocRemoveDeallocMemrefsContainedInRetained
87+
: public OpRewritePattern<DeallocOp> {
88+
DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
89+
AliasAnalysis &aliasAnalysis)
90+
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
91+
92+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
93+
PatternRewriter &rewriter) const override {
94+
// Maps each retained memref to its index in the `retained` operand list.
95+
DenseMap<Value, unsigned> retained;
96+
for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
97+
retained[ret] = i;
98+
99+
// There must not be any duplicates in the retain list anymore because we
100+
// would miss updating one of the result values otherwise.
101+
if (retained.size() != deallocOp.getRetained().size())
102+
return failure();
103+
104+
SmallVector<Value> newMemrefs, newConditions;
105+
for (auto memrefAndCond :
106+
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
107+
Value memref = std::get<0>(memrefAndCond);
108+
Value cond = std::get<1>(memrefAndCond);
109+
110+
auto replaceResultsIfNoInvalidAliasing = [&](Value memref) -> bool {
111+
Value retainedMemref = deallocOp.getRetained()[retained[memref]];
112+
// The current memref must not have a may-alias relation to any retained
113+
// memref, and exactly one must-alias relation.
114+
// TODO: it is possible to extend this pattern to allow an arbitrary
115+
// number of must-alias relations as long as there is no may-alias. If
116+
// it's no-alias, then just proceed (only supported case as of now), if
117+
// it's must-alias, we also need to update the condition for that alias.
118+
if (llvm::all_of(deallocOp.getRetained(), [&](Value mr) {
119+
return aliasAnalysis.alias(mr, memref).isNo() ||
120+
mr == retainedMemref;
121+
})) {
122+
rewriter.setInsertionPointAfter(deallocOp);
123+
auto orOp = rewriter.create<arith::OrIOp>(
124+
deallocOp.getLoc(),
125+
deallocOp.getUpdatedConditions()[retained[memref]], cond);
126+
rewriter.replaceAllUsesExcept(
127+
deallocOp.getUpdatedConditions()[retained[memref]],
128+
orOp.getResult(), orOp);
129+
return true;
130+
}
131+
return false;
132+
};
133+
134+
if (retained.contains(memref) &&
135+
replaceResultsIfNoInvalidAliasing(memref))
136+
continue;
137+
138+
auto extractOp = memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
139+
if (extractOp && retained.contains(extractOp.getOperand()) &&
140+
replaceResultsIfNoInvalidAliasing(extractOp.getOperand()))
141+
continue;
142+
143+
newMemrefs.push_back(memref);
144+
newConditions.push_back(cond);
145+
}
146+
147+
// Return failure if we don't change anything such that we don't run into an
148+
// infinite loop of pattern applications.
149+
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
150+
rewriter);
151+
}
152+
153+
private:
154+
AliasAnalysis &aliasAnalysis;
155+
};
156+
157+
} // namespace
158+
159+
//===----------------------------------------------------------------------===//
160+
// BufferDeallocationSimplificationPass
161+
//===----------------------------------------------------------------------===//
162+
163+
namespace {
164+
165+
/// The pass that simplifies `bufferization.dealloc` operations by greedily
166+
/// applying rewrite patterns that query the alias analysis and therefore
167+
/// cannot be registered as regular canonicalization patterns.
168+
struct BufferDeallocationSimplificationPass
169+
: public bufferization::impl::BufferDeallocationSimplificationBase<
170+
BufferDeallocationSimplificationPass> {
171+
void runOnOperation() override {
172+
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
173+
RewritePatternSet patterns(&getContext());
174+
patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained>(&getContext(),
175+
aliasAnalysis);
176+
177+
if (failed(
178+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
179+
signalPassFailure();
180+
}
181+
};
182+
183+
} // namespace
184+
185+
std::unique_ptr<Pass>
186+
mlir::bufferization::createBufferDeallocationSimplificationPass() {
187+
return std::make_unique<BufferDeallocationSimplificationPass>();
188+
}

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRBufferizationTransforms
22
Bufferize.cpp
33
BufferDeallocation.cpp
4+
BufferDeallocationSimplification.cpp
45
BufferOptimizations.cpp
56
BufferResultsToOutParams.cpp
67
BufferUtils.cpp
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s
2+
3+
func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
4+
%0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
5+
%1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
6+
%2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
7+
return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
8+
}
9+
10+
// CHECK-LABEL: func @dealloc_deallocated_in_retained
11+
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
12+
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
13+
// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
14+
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
15+
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
16+
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
17+
// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
18+
19+
// -----
20+
21+
func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
22+
%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
23+
%base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
24+
%0 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0 : memref<2xi32>)
25+
%1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref<i32>, memref<i32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
26+
%2:2 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
27+
return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
28+
}
29+
30+
// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
31+
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
32+
// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
33+
// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
34+
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
35+
// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
36+
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
37+
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
38+
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
39+
// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,19 +297,16 @@ func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg
297297

298298
// -----
299299

300-
func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) {
301-
%0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
302-
%1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
300+
func.func @dealloc_erase_empty(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> i1 {
303301
bufferization.dealloc
304-
bufferization.dealloc retain (%arg0 : memref<2xi32>)
305-
return %0, %1 : i1, i1
302+
%0 = bufferization.dealloc retain (%arg0 : memref<2xi32>)
303+
return %0 : i1
306304
}
307305

308-
// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
306+
// CHECK-LABEL: func @dealloc_erase_empty
309307
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
310-
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
311-
// CHECK-NEXT: [[V1:%.+]] = arith.ori [[V0]], [[ARG1]]
312-
// CHECK-NEXT: return [[ARG1]], [[V1]] :
308+
// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
309+
// CHECK-NEXT: return [[FALSE]] :
313310

314311
// -----
315312

0 commit comments

Comments
 (0)