Skip to content

Commit 0bcae5e

Browse files
committed
[mlir][bufferization] Add pattern to BufferDeallocationSimplification pass
Add a pattern that splits one dealloc operation into multiple dealloc operation depending on static aliasing information of the values in the `memref` operand list. This reduces the total number of aliasing checks required at runtime and can enable futher canonicalizations of the new and simplified dealloc operations. Depends on D157407 Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157508
1 parent 211ed03 commit 0bcae5e

File tree

2 files changed

+128
-12
lines changed

2 files changed

+128
-12
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 102 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
4848
return success();
4949
}
5050

51+
/// Checks if 'memref' may or must alias a MemRef in 'memrefList'. It is often a
52+
/// requirement of optimization patterns that there cannot be any aliasing
53+
/// memref in order to perform the desired simplification. The 'allowSelfAlias'
54+
/// argument indicates whether 'memref' may be present in 'memrefList' which
55+
/// makes this helper function applicable to situations where we already know
56+
/// that 'memref' is in the list but also when we don't want it in the list.
57+
static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
58+
ValueRange memrefList, Value memref,
59+
bool allowSelfAlias) {
60+
for (auto mr : memrefList) {
61+
if (allowSelfAlias && mr == memref)
62+
continue;
63+
if (!analysis.alias(mr, memref).isNo())
64+
return true;
65+
}
66+
return false;
67+
}
68+
5169
//===----------------------------------------------------------------------===//
5270
// Patterns
5371
//===----------------------------------------------------------------------===//
@@ -176,15 +194,6 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
176194
AliasAnalysis &aliasAnalysis)
177195
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
178196

179-
bool potentiallyAliasesMemref(DeallocOp deallocOp,
180-
Value retainedMemref) const {
181-
for (auto memref : deallocOp.getMemrefs()) {
182-
if (!aliasAnalysis.alias(memref, retainedMemref).isNo())
183-
return true;
184-
}
185-
return false;
186-
}
187-
188197
LogicalResult matchAndRewrite(DeallocOp deallocOp,
189198
PatternRewriter &rewriter) const override {
190199
SmallVector<Value> newRetainedMemrefs, replacements;
@@ -197,7 +206,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
197206
};
198207

199208
for (auto retainedMemref : deallocOp.getRetained()) {
200-
if (potentiallyAliasesMemref(deallocOp, retainedMemref)) {
209+
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
210+
retainedMemref, false)) {
201211
newRetainedMemrefs.push_back(retainedMemref);
202212
replacements.push_back({});
203213
continue;
@@ -226,6 +236,85 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
226236
AliasAnalysis &aliasAnalysis;
227237
};
228238

239+
/// Split off memrefs to separate dealloc operations to reduce the number of
240+
/// runtime checks required and enable further canonicalization of the new and
241+
/// simpler dealloc operations. A memref can be split off if it is guaranteed to
242+
/// not alias with any other memref in the `memref` operand list. The results
243+
/// of the old and the new dealloc operation have to be combined by computing
244+
/// the element-wise disjunction of them.
245+
///
246+
/// Example:
247+
/// ```mlir
248+
/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
249+
/// if (%cond0, %cond1)
250+
/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
251+
/// return %0#0, %0#1
252+
/// ```
253+
/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
254+
/// canonicalized to the following, thus reducing the number of runtime alias
255+
/// checks by 1 and potentially enabling further canonicalization of the new
256+
/// split-up dealloc operations.
257+
/// ```mlir
258+
/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
259+
/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
260+
/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
261+
/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
262+
/// %2 = arith.ori %0#0, %1#0
263+
/// %3 = arith.ori %0#1, %1#1
264+
/// return %2, %3
265+
/// ```
266+
struct SplitDeallocWhenNotAliasingAnyOther
267+
: public OpRewritePattern<DeallocOp> {
268+
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
269+
AliasAnalysis &aliasAnalysis)
270+
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
271+
272+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
273+
PatternRewriter &rewriter) const override {
274+
if (deallocOp.getMemrefs().size() <= 1)
275+
return failure();
276+
277+
SmallVector<Value> newMemrefs, newConditions, replacements;
278+
DenseSet<Operation *> exceptedUsers;
279+
replacements = deallocOp.getUpdatedConditions();
280+
for (auto [memref, cond] :
281+
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
282+
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
283+
memref, true)) {
284+
newMemrefs.push_back(memref);
285+
newConditions.push_back(cond);
286+
continue;
287+
}
288+
289+
auto newDeallocOp = rewriter.create<DeallocOp>(
290+
deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
291+
replacements = SmallVector<Value>(llvm::map_range(
292+
llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
293+
[&](auto replAndNew) -> Value {
294+
auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
295+
std::get<0>(replAndNew),
296+
std::get<1>(replAndNew));
297+
exceptedUsers.insert(orOp);
298+
return orOp.getResult();
299+
}));
300+
}
301+
302+
if (newMemrefs.size() == deallocOp.getMemrefs().size())
303+
return failure();
304+
305+
rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
306+
[&](OpOperand &operand) {
307+
return !exceptedUsers.contains(
308+
operand.getOwner());
309+
});
310+
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
311+
rewriter);
312+
}
313+
314+
private:
315+
AliasAnalysis &aliasAnalysis;
316+
};
317+
229318
} // namespace
230319

231320
//===----------------------------------------------------------------------===//
@@ -244,8 +333,9 @@ struct BufferDeallocationSimplificationPass
244333
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
245334
RewritePatternSet patterns(&getContext());
246335
patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
247-
RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext(),
248-
aliasAnalysis);
336+
RemoveRetainedMemrefsGuaranteedToNotAlias,
337+
SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
338+
aliasAnalysis);
249339

250340
if (failed(
251341
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))

mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,29 @@ func.func @dealloc_deallocated_in_retained(%arg0: i1, %arg1: memref<2xi32>) -> (
5252
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
5353
// CHECK-NOT: retain
5454
// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
55+
56+
// -----
57+
58+
func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1) {
59+
%alloc = memref.alloc() : memref<2xi32>
60+
%alloc0 = memref.alloc() : memref<2xi32>
61+
%0 = arith.select %arg0, %alloc, %alloc0 : memref<2xi32>
62+
%1:2 = bufferization.dealloc (%alloc, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg0, %arg3) retain (%arg1, %0 : memref<2xi32>, memref<2xi32>)
63+
return %1#0, %1#1 : i1, i1
64+
}
65+
66+
// CHECK-LABEL: func @dealloc_split_when_no_other_aliasing
67+
// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
68+
// CHECK-NEXT: [[ALLOC0:%.+]] = memref.alloc(
69+
// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
70+
// CHECK-NEXT: [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
71+
// COM: there is only one value in the retained list because the
72+
// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
73+
// COM: removes %arg1 from the list. In the second dealloc, this does not apply
74+
// COM: because function arguments are assumed potentially alias (even if the
75+
// COM: types don't exactly match).
76+
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
77+
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
78+
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
79+
// CHECK-NEXT: bufferization.dealloc
80+
// CHECK-NEXT: return [[V2]]#0, [[V3]] :

0 commit comments

Comments
 (0)