Skip to content

Commit be4524e

Browse files
[mlir][bufferization] BufferOriginAnalysis
1 parent cceedc9 commit be4524e

File tree

6 files changed

+321
-56
lines changed

6 files changed

+321
-56
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define BUFFERIZATION_OPS
1111

1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
13+
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1314
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
1415
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1516
include "mlir/Interfaces/DestinationStyleOpInterface.td"

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
5353
///
5454
/// Results in resolve(B) returning {B, C}
5555
ValueSetT resolve(Value value) const;
56+
ValueSetT resolveReverse(Value value) const;
5657

5758
/// Removes the given values from all alias sets.
5859
void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
7374

7475
/// Maps values to all immediate dependencies this value can have.
7576
ValueMapT dependencies;
77+
ValueMapT reverseDependencies;
7678

7779
/// A set of all SSA values that may be terminal buffers.
7880
DenseSet<Value> terminals;
7981
};
8082

83+
/// An is-same-buffer analysis that checks if two SSA values belong to the same
84+
/// buffer allocation or not.
85+
class BufferOriginAnalysis {
86+
public:
87+
BufferOriginAnalysis(Operation *op);
88+
89+
/// Return "true" if `v1` and `v2` originate from the same buffer allocation.
90+
/// Return "false" if `v1` and `v2` originate from different allocations.
91+
/// Return "nullopt" if we do not know for sure.
92+
///
93+
/// Example 1: isSameAllocation(%0, %1) == true
94+
/// ```
95+
/// %0 = memref.alloc()
96+
/// %1 = memref.subview %0
97+
/// ```
98+
///
99+
/// Example 2: isSameAllocation(%0, %1) == false
100+
/// ```
101+
/// %0 = memref.alloc()
102+
/// %1 = memref.alloc()
103+
/// ```
104+
///
105+
/// Example 3: isSameAllocation(%0, %2) == nullopt
106+
/// ```
107+
/// %0 = memref.alloc()
108+
/// %1 = memref.alloc()
109+
/// %2 = arith.select %c, %0, %1
110+
/// ```
111+
std::optional<bool> isSameAllocation(Value v1, Value v2);
112+
113+
private:
114+
BufferViewFlowAnalysis analysis;
115+
};
116+
81117
} // namespace mlir
82118

83119
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15-
#include "mlir/Analysis/AliasAnalysis.h"
1615
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16+
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
1717
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
1919
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
3434
// Helpers
3535
//===----------------------------------------------------------------------===//
3636

37+
/// Given a memref value, return the "base" value by skipping over all
38+
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39+
static Value getViewBase(Value value) {
40+
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
41+
value = viewLikeOp.getViewSource();
42+
return value;
43+
}
44+
3745
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
3846
ValueRange memrefs,
3947
ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
4957
return success();
5058
}
5159

52-
/// Given a memref value, return the "base" value by skipping over all
53-
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
54-
static Value getViewBase(Value value) {
55-
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
56-
value = viewLikeOp.getViewSource();
57-
return value;
58-
}
59-
6060
/// Return "true" if the given values are guaranteed to be different (and
6161
/// non-aliasing) allocations based on the fact that one value is the result
6262
/// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
8080
/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
8181
/// often a requirement of optimization patterns that there cannot be any
8282
/// aliasing memref in order to perform the desired simplification.
83-
static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
83+
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
8484
ValueRange otherList, Value memref) {
8585
for (auto other : otherList) {
8686
if (distinctAllocAndBlockArgument(other, memref))
8787
continue;
88-
if (!analysis.alias(other, memref).isNo())
88+
std::optional<bool> analysisResult =
89+
analysis.isSameAllocation(other, memref);
90+
if (!analysisResult.has_value() || analysisResult == true)
8991
return true;
9092
}
9193
return false;
@@ -129,8 +131,8 @@ namespace {
129131
struct RemoveDeallocMemrefsContainedInRetained
130132
: public OpRewritePattern<DeallocOp> {
131133
RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
132-
AliasAnalysis &aliasAnalysis)
133-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
134+
BufferOriginAnalysis &analysis)
135+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
134136

135137
/// The passed 'memref' must not have a may-alias relation to any retained
136138
/// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
147149
// deallocated in some situations and can thus not be dropped).
148150
bool atLeastOneMustAlias = false;
149151
for (Value retained : deallocOp.getRetained()) {
150-
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
151-
if (analysisResult.isMay())
152+
std::optional<bool> analysisResult =
153+
analysis.isSameAllocation(retained, memref);
154+
if (!analysisResult.has_value())
152155
return failure();
153-
if (analysisResult.isMust() || analysisResult.isPartial())
156+
if (analysisResult == true)
154157
atLeastOneMustAlias = true;
155158
}
156159
if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
161164
// we can remove that operand later on.
162165
for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
163166
Value updatedCondition = deallocOp.getUpdatedConditions()[i];
164-
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
165-
if (analysisResult.isMust() || analysisResult.isPartial()) {
167+
std::optional<bool> analysisResult =
168+
analysis.isSameAllocation(retained, memref);
169+
if (analysisResult == true) {
166170
auto disjunction = rewriter.create<arith::OrIOp>(
167171
deallocOp.getLoc(), updatedCondition, cond);
168172
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
206210
}
207211

208212
private:
209-
AliasAnalysis &aliasAnalysis;
213+
BufferOriginAnalysis &analysis;
210214
};
211215

212216
/// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
228232
struct RemoveRetainedMemrefsGuaranteedToNotAlias
229233
: public OpRewritePattern<DeallocOp> {
230234
RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
231-
AliasAnalysis &aliasAnalysis)
232-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
235+
BufferOriginAnalysis &analysis)
236+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
233237

234238
LogicalResult matchAndRewrite(DeallocOp deallocOp,
235239
PatternRewriter &rewriter) const override {
236240
SmallVector<Value> newRetainedMemrefs, replacements;
237241

238242
for (auto retainedMemref : deallocOp.getRetained()) {
239-
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
243+
if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
240244
retainedMemref)) {
241245
newRetainedMemrefs.push_back(retainedMemref);
242246
replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
264268
}
265269

266270
private:
267-
AliasAnalysis &aliasAnalysis;
271+
BufferOriginAnalysis &analysis;
268272
};
269273

270274
/// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
297301
struct SplitDeallocWhenNotAliasingAnyOther
298302
: public OpRewritePattern<DeallocOp> {
299303
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
300-
AliasAnalysis &aliasAnalysis)
301-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
304+
BufferOriginAnalysis &analysis)
305+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
302306

303307
LogicalResult matchAndRewrite(DeallocOp deallocOp,
304308
PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
314318
SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
315319
otherMemrefs.erase(otherMemrefs.begin() + i);
316320
// Check if `memref` can split off into a separate bufferization.dealloc.
317-
if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
321+
if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
318322
// `memref` alias with other memrefs, do not split off.
319323
remainingMemrefs.push_back(memref);
320324
remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
352356
}
353357

354358
private:
355-
AliasAnalysis &aliasAnalysis;
359+
BufferOriginAnalysis &analysis;
356360
};
357361

358362
/// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
381385
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
382386
: public OpRewritePattern<DeallocOp> {
383387
RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
384-
AliasAnalysis &aliasAnalysis)
385-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
388+
BufferOriginAnalysis &analysis)
389+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
386390

387391
LogicalResult matchAndRewrite(DeallocOp deallocOp,
388392
PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
396400
if (!matchPattern(cond, m_One()))
397401
continue;
398402

399-
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
400-
if (analysisResult.isMust() || analysisResult.isPartial()) {
403+
std::optional<bool> analysisResult =
404+
analysis.isSameAllocation(retained, memref);
405+
if (analysisResult == true) {
401406
rewriter.replaceAllUsesWith(res, cond);
402407
aliasesWithConstTrueMemref[i] = true;
403408
canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
411416
if (!extractOp)
412417
continue;
413418

414-
AliasResult extractAnalysisResult =
415-
aliasAnalysis.alias(retained, extractOp.getOperand());
416-
if (extractAnalysisResult.isMust() ||
417-
extractAnalysisResult.isPartial()) {
419+
std::optional<bool> extractAnalysisResult =
420+
analysis.isSameAllocation(retained, extractOp.getOperand());
421+
if (extractAnalysisResult == true) {
418422
rewriter.replaceAllUsesWith(res, cond);
419423
aliasesWithConstTrueMemref[i] = true;
420424
canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
434438
}
435439

436440
private:
437-
AliasAnalysis &aliasAnalysis;
441+
BufferOriginAnalysis &analysis;
438442
};
439443

440444
} // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
452456
: public bufferization::impl::BufferDeallocationSimplificationBase<
453457
BufferDeallocationSimplificationPass> {
454458
void runOnOperation() override {
455-
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
459+
BufferOriginAnalysis analysis(getOperation());
456460
RewritePatternSet patterns(&getContext());
457461
patterns.add<RemoveDeallocMemrefsContainedInRetained,
458462
RemoveRetainedMemrefsGuaranteedToNotAlias,
459463
SplitDeallocWhenNotAliasingAnyOther,
460464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
461-
aliasAnalysis);
465+
analysis);
462466
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
463467

464468
if (failed(

0 commit comments

Comments
 (0)