12
12
//
13
13
// ===----------------------------------------------------------------------===//
14
14
15
- #include " mlir/Analysis/AliasAnalysis.h"
16
15
#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
16
+ #include " mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
17
17
#include " mlir/Dialect/Bufferization/Transforms/Passes.h"
18
18
#include " mlir/Dialect/Func/IR/FuncOps.h"
19
19
#include " mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
34
34
// Helpers
35
35
// ===----------------------------------------------------------------------===//
36
36
37
+ // / Given a memref value, return the "base" value by skipping over all
38
+ // / ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39
+ static Value getViewBase (Value value) {
40
+ while (auto viewLikeOp = value.getDefiningOp <ViewLikeOpInterface>())
41
+ value = viewLikeOp.getViewSource ();
42
+ return value;
43
+ }
44
+
37
45
static LogicalResult updateDeallocIfChanged (DeallocOp deallocOp,
38
46
ValueRange memrefs,
39
47
ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
49
57
return success ();
50
58
}
51
59
52
- // / Given a memref value, return the "base" value by skipping over all
53
- // / ViewLikeOpInterface ops (if any) in the reverse use-def chain.
54
- static Value getViewBase (Value value) {
55
- while (auto viewLikeOp = value.getDefiningOp <ViewLikeOpInterface>())
56
- value = viewLikeOp.getViewSource ();
57
- return value;
58
- }
59
-
60
60
// / Return "true" if the given values are guaranteed to be different (and
61
61
// / non-aliasing) allocations based on the fact that one value is the result
62
62
// / of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
80
80
// / Checks if `memref` may potentially alias a MemRef in `otherList`. It is
81
81
// / often a requirement of optimization patterns that there cannot be any
82
82
// / aliasing memref in order to perform the desired simplification.
83
- static bool potentiallyAliasesMemref (AliasAnalysis &analysis,
83
+ static bool potentiallyAliasesMemref (BufferOriginAnalysis &analysis,
84
84
ValueRange otherList, Value memref) {
85
85
for (auto other : otherList) {
86
86
if (distinctAllocAndBlockArgument (other, memref))
87
87
continue ;
88
- if (!analysis.alias (other, memref).isNo ())
88
+ std::optional<bool > analysisResult =
89
+ analysis.isSameAllocation (other, memref);
90
+ if (!analysisResult.has_value () || analysisResult == true )
89
91
return true ;
90
92
}
91
93
return false ;
@@ -129,8 +131,8 @@ namespace {
129
131
struct RemoveDeallocMemrefsContainedInRetained
130
132
: public OpRewritePattern<DeallocOp> {
131
133
RemoveDeallocMemrefsContainedInRetained (MLIRContext *context,
132
- AliasAnalysis &aliasAnalysis )
133
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis ) {}
134
+ BufferOriginAnalysis &analysis )
135
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis ) {}
134
136
135
137
// / The passed 'memref' must not have a may-alias relation to any retained
136
138
// / memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
147
149
// deallocated in some situations and can thus not be dropped).
148
150
bool atLeastOneMustAlias = false ;
149
151
for (Value retained : deallocOp.getRetained ()) {
150
- AliasResult analysisResult = aliasAnalysis.alias (retained, memref);
151
- if (analysisResult.isMay ())
152
+ std::optional<bool > analysisResult =
153
+ analysis.isSameAllocation (retained, memref);
154
+ if (!analysisResult.has_value ())
152
155
return failure ();
153
- if (analysisResult. isMust () || analysisResult. isPartial () )
156
+ if (analysisResult == true )
154
157
atLeastOneMustAlias = true ;
155
158
}
156
159
if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
161
164
// we can remove that operand later on.
162
165
for (auto [i, retained] : llvm::enumerate (deallocOp.getRetained ())) {
163
166
Value updatedCondition = deallocOp.getUpdatedConditions ()[i];
164
- AliasResult analysisResult = aliasAnalysis.alias (retained, memref);
165
- if (analysisResult.isMust () || analysisResult.isPartial ()) {
167
+ std::optional<bool > analysisResult =
168
+ analysis.isSameAllocation (retained, memref);
169
+ if (analysisResult == true ) {
166
170
auto disjunction = rewriter.create <arith::OrIOp>(
167
171
deallocOp.getLoc (), updatedCondition, cond);
168
172
rewriter.replaceAllUsesExcept (updatedCondition, disjunction.getResult (),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
206
210
}
207
211
208
212
private:
209
- AliasAnalysis &aliasAnalysis ;
213
+ BufferOriginAnalysis &analysis ;
210
214
};
211
215
212
216
// / Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
228
232
struct RemoveRetainedMemrefsGuaranteedToNotAlias
229
233
: public OpRewritePattern<DeallocOp> {
230
234
RemoveRetainedMemrefsGuaranteedToNotAlias (MLIRContext *context,
231
- AliasAnalysis &aliasAnalysis )
232
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis ) {}
235
+ BufferOriginAnalysis &analysis )
236
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis ) {}
233
237
234
238
LogicalResult matchAndRewrite (DeallocOp deallocOp,
235
239
PatternRewriter &rewriter) const override {
236
240
SmallVector<Value> newRetainedMemrefs, replacements;
237
241
238
242
for (auto retainedMemref : deallocOp.getRetained ()) {
239
- if (potentiallyAliasesMemref (aliasAnalysis , deallocOp.getMemrefs (),
243
+ if (potentiallyAliasesMemref (analysis , deallocOp.getMemrefs (),
240
244
retainedMemref)) {
241
245
newRetainedMemrefs.push_back (retainedMemref);
242
246
replacements.push_back ({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
264
268
}
265
269
266
270
private:
267
- AliasAnalysis &aliasAnalysis ;
271
+ BufferOriginAnalysis &analysis ;
268
272
};
269
273
270
274
// / Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
297
301
struct SplitDeallocWhenNotAliasingAnyOther
298
302
: public OpRewritePattern<DeallocOp> {
299
303
SplitDeallocWhenNotAliasingAnyOther (MLIRContext *context,
300
- AliasAnalysis &aliasAnalysis )
301
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis ) {}
304
+ BufferOriginAnalysis &analysis )
305
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis ) {}
302
306
303
307
LogicalResult matchAndRewrite (DeallocOp deallocOp,
304
308
PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
314
318
SmallVector<Value> otherMemrefs (deallocOp.getMemrefs ());
315
319
otherMemrefs.erase (otherMemrefs.begin () + i);
316
320
// Check if `memref` can split off into a separate bufferization.dealloc.
317
- if (potentiallyAliasesMemref (aliasAnalysis , otherMemrefs, memref)) {
321
+ if (potentiallyAliasesMemref (analysis , otherMemrefs, memref)) {
318
322
// `memref` alias with other memrefs, do not split off.
319
323
remainingMemrefs.push_back (memref);
320
324
remainingConditions.push_back (cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
352
356
}
353
357
354
358
private:
355
- AliasAnalysis &aliasAnalysis ;
359
+ BufferOriginAnalysis &analysis ;
356
360
};
357
361
358
362
// / Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
381
385
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
382
386
: public OpRewritePattern<DeallocOp> {
383
387
RetainedMemrefAliasingAlwaysDeallocatedMemref (MLIRContext *context,
384
- AliasAnalysis &aliasAnalysis )
385
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis ) {}
388
+ BufferOriginAnalysis &analysis )
389
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis ) {}
386
390
387
391
LogicalResult matchAndRewrite (DeallocOp deallocOp,
388
392
PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
396
400
if (!matchPattern (cond, m_One ()))
397
401
continue ;
398
402
399
- AliasResult analysisResult = aliasAnalysis.alias (retained, memref);
400
- if (analysisResult.isMust () || analysisResult.isPartial ()) {
403
+ std::optional<bool > analysisResult =
404
+ analysis.isSameAllocation (retained, memref);
405
+ if (analysisResult == true ) {
401
406
rewriter.replaceAllUsesWith (res, cond);
402
407
aliasesWithConstTrueMemref[i] = true ;
403
408
canDropMemref = true ;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
411
416
if (!extractOp)
412
417
continue ;
413
418
414
- AliasResult extractAnalysisResult =
415
- aliasAnalysis.alias (retained, extractOp.getOperand ());
416
- if (extractAnalysisResult.isMust () ||
417
- extractAnalysisResult.isPartial ()) {
419
+ std::optional<bool > extractAnalysisResult =
420
+ analysis.isSameAllocation (retained, extractOp.getOperand ());
421
+ if (extractAnalysisResult == true ) {
418
422
rewriter.replaceAllUsesWith (res, cond);
419
423
aliasesWithConstTrueMemref[i] = true ;
420
424
canDropMemref = true ;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
434
438
}
435
439
436
440
private:
437
- AliasAnalysis &aliasAnalysis ;
441
+ BufferOriginAnalysis &analysis ;
438
442
};
439
443
440
444
} // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
452
456
: public bufferization::impl::BufferDeallocationSimplificationBase<
453
457
BufferDeallocationSimplificationPass> {
454
458
void runOnOperation () override {
455
- AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>( );
459
+ BufferOriginAnalysis analysis ( getOperation () );
456
460
RewritePatternSet patterns (&getContext ());
457
461
patterns.add <RemoveDeallocMemrefsContainedInRetained,
458
462
RemoveRetainedMemrefsGuaranteedToNotAlias,
459
463
SplitDeallocWhenNotAliasingAnyOther,
460
464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext (),
461
- aliasAnalysis );
465
+ analysis );
462
466
populateDeallocOpCanonicalizationPatterns (patterns, &getContext ());
463
467
464
468
if (failed (
0 commit comments