@@ -48,6 +48,24 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
48
48
return success ();
49
49
}
50
50
51
+ // / Checks if 'memref' may or must alias a MemRef in 'memrefList'. It is often a
52
+ // / requirement of optimization patterns that there cannot be any aliasing
53
+ // / memref in order to perform the desired simplification. The 'allowSelfAlias'
54
+ // / argument indicates whether 'memref' may be present in 'memrefList' which
55
+ // / makes this helper function applicable to situations where we already know
56
+ // / that 'memref' is in the list but also when we don't want it in the list.
57
+ static bool potentiallyAliasesMemref (AliasAnalysis &analysis,
58
+ ValueRange memrefList, Value memref,
59
+ bool allowSelfAlias) {
60
+ for (auto mr : memrefList) {
61
+ if (allowSelfAlias && mr == memref)
62
+ continue ;
63
+ if (!analysis.alias (mr, memref).isNo ())
64
+ return true ;
65
+ }
66
+ return false ;
67
+ }
68
+
51
69
// ===----------------------------------------------------------------------===//
52
70
// Patterns
53
71
// ===----------------------------------------------------------------------===//
@@ -176,15 +194,6 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
176
194
AliasAnalysis &aliasAnalysis)
177
195
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
178
196
179
- bool potentiallyAliasesMemref (DeallocOp deallocOp,
180
- Value retainedMemref) const {
181
- for (auto memref : deallocOp.getMemrefs ()) {
182
- if (!aliasAnalysis.alias (memref, retainedMemref).isNo ())
183
- return true ;
184
- }
185
- return false ;
186
- }
187
-
188
197
LogicalResult matchAndRewrite (DeallocOp deallocOp,
189
198
PatternRewriter &rewriter) const override {
190
199
SmallVector<Value> newRetainedMemrefs, replacements;
@@ -197,7 +206,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
197
206
};
198
207
199
208
for (auto retainedMemref : deallocOp.getRetained ()) {
200
- if (potentiallyAliasesMemref (deallocOp, retainedMemref)) {
209
+ if (potentiallyAliasesMemref (aliasAnalysis, deallocOp.getMemrefs (),
210
+ retainedMemref, false )) {
201
211
newRetainedMemrefs.push_back (retainedMemref);
202
212
replacements.push_back ({});
203
213
continue ;
@@ -226,6 +236,85 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
226
236
AliasAnalysis &aliasAnalysis;
227
237
};
228
238
239
+ // / Split off memrefs to separate dealloc operations to reduce the number of
240
+ // / runtime checks required and enable further canonicalization of the new and
241
+ // / simpler dealloc operations. A memref can be split off if it is guaranteed to
242
+ // / not alias with any other memref in the `memref` operand list. The results
243
+ // / of the old and the new dealloc operation have to be combined by computing
244
+ // / the element-wise disjunction of them.
245
+ // /
246
+ // / Example:
247
+ // / ```mlir
248
+ // / %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
249
+ // / if (%cond0, %cond1)
250
+ // / retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
251
+ // / return %0#0, %0#1
252
+ // / ```
253
+ // / Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
254
+ // / canonicalized to the following, thus reducing the number of runtime alias
255
+ // / checks by 1 and potentially enabling further canonicalization of the new
256
+ // / split-up dealloc operations.
257
+ // / ```mlir
258
+ // / %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
259
+ // / retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
260
+ // / %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
261
+ // / retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
262
+ // / %2 = arith.ori %0#0, %1#0
263
+ // / %3 = arith.ori %0#1, %1#1
264
+ // / return %2, %3
265
+ // / ```
266
+ struct SplitDeallocWhenNotAliasingAnyOther
267
+ : public OpRewritePattern<DeallocOp> {
268
+ SplitDeallocWhenNotAliasingAnyOther (MLIRContext *context,
269
+ AliasAnalysis &aliasAnalysis)
270
+ : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
271
+
272
+ LogicalResult matchAndRewrite (DeallocOp deallocOp,
273
+ PatternRewriter &rewriter) const override {
274
+ if (deallocOp.getMemrefs ().size () <= 1 )
275
+ return failure ();
276
+
277
+ SmallVector<Value> newMemrefs, newConditions, replacements;
278
+ DenseSet<Operation *> exceptedUsers;
279
+ replacements = deallocOp.getUpdatedConditions ();
280
+ for (auto [memref, cond] :
281
+ llvm::zip (deallocOp.getMemrefs (), deallocOp.getConditions ())) {
282
+ if (potentiallyAliasesMemref (aliasAnalysis, deallocOp.getMemrefs (),
283
+ memref, true )) {
284
+ newMemrefs.push_back (memref);
285
+ newConditions.push_back (cond);
286
+ continue ;
287
+ }
288
+
289
+ auto newDeallocOp = rewriter.create <DeallocOp>(
290
+ deallocOp.getLoc (), memref, cond, deallocOp.getRetained ());
291
+ replacements = SmallVector<Value>(llvm::map_range (
292
+ llvm::zip (replacements, newDeallocOp.getUpdatedConditions ()),
293
+ [&](auto replAndNew) -> Value {
294
+ auto orOp = rewriter.create <arith::OrIOp>(deallocOp.getLoc (),
295
+ std::get<0 >(replAndNew),
296
+ std::get<1 >(replAndNew));
297
+ exceptedUsers.insert (orOp);
298
+ return orOp.getResult ();
299
+ }));
300
+ }
301
+
302
+ if (newMemrefs.size () == deallocOp.getMemrefs ().size ())
303
+ return failure ();
304
+
305
+ rewriter.replaceUsesWithIf (deallocOp.getUpdatedConditions (), replacements,
306
+ [&](OpOperand &operand) {
307
+ return !exceptedUsers.contains (
308
+ operand.getOwner ());
309
+ });
310
+ return updateDeallocIfChanged (deallocOp, newMemrefs, newConditions,
311
+ rewriter);
312
+ }
313
+
314
+ private:
315
+ AliasAnalysis &aliasAnalysis;
316
+ };
317
+
229
318
} // namespace
230
319
231
320
// ===----------------------------------------------------------------------===//
@@ -244,8 +333,9 @@ struct BufferDeallocationSimplificationPass
244
333
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
245
334
RewritePatternSet patterns (&getContext ());
246
335
patterns.add <DeallocRemoveDeallocMemrefsContainedInRetained,
247
- RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext (),
248
- aliasAnalysis);
336
+ RemoveRetainedMemrefsGuaranteedToNotAlias,
337
+ SplitDeallocWhenNotAliasingAnyOther>(&getContext (),
338
+ aliasAnalysis);
249
339
250
340
if (failed (
251
341
applyPatternsAndFoldGreedily (getOperation (), std::move (patterns))))
0 commit comments