Skip to content

Commit 83ca1fb

Browse files
[mlir][bufferization] Add BufferOriginAnalysis
This commit adds the `BufferOriginAnalysis`, which can be queried to check if two buffer SSA values originate from the same allocation. This new analysis is used in the buffer deallocation pass to fold away or simplify `bufferization.dealloc` ops more aggressively. The `BufferOriginAnalysis` is based on the `BufferViewFlowAnalysis`, which collects buffer SSA value "same buffer" dependencies. E.g., given IR such as: ``` %0 = memref.alloc() %1 = memref.subview %0 %2 = memref.subview %1 ``` The `BufferViewFlowAnalysis` will report the following "reverse" dependencies (`resolveReverse`) for `%2`: {`%2`, `%1`, `%0`}. I.e., all buffer SSA values that originate from the same allocation as `%2`. The `BufferOriginAnalysis` is built on top of that. It handles only simple cases at the moment and may conservatively return "unknown" around certain IR with branches, memref globals and function arguments. This analysis enables additional simplifications during `-buffer-deallocation-simplification`. In particular, "regular" scf.for loop nests, that yield buffers (or reallocations thereof) in the same order as they appear in the iter_args, are now handled much more efficiently. (TODO: Add test case.) Such IR patterns are generated by the sparse compiler.
1 parent cceedc9 commit 83ca1fb

File tree

5 files changed

+235
-56
lines changed

5 files changed

+235
-56
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define BUFFERIZATION_OPS
1111

1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
13+
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1314
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
1415
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1516
include "mlir/Interfaces/DestinationStyleOpInterface.td"

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
5353
///
5454
/// Results in resolve(B) returning {B, C}
5555
ValueSetT resolve(Value value) const;
56+
ValueSetT resolveReverse(Value value) const;
5657

5758
/// Removes the given values from all alias sets.
5859
void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
7374

7475
/// Maps values to all immediate dependencies this value can have.
7576
ValueMapT dependencies;
77+
ValueMapT reverseDependencies;
7678

7779
/// A set of all SSA values that may be terminal buffers.
7880
DenseSet<Value> terminals;
7981
};
8082

83+
/// An is-same-buffer analysis that checks if two SSA values belong to the same
84+
/// buffer allocation or not.
85+
class BufferOriginAnalysis {
86+
public:
87+
BufferOriginAnalysis(Operation *op);
88+
89+
/// Return "true" if `v1` and `v2` originate from the same buffer allocation.
90+
/// Return "false" if `v1` and `v2` originate from different allocations.
91+
/// Return "nullopt" if we do not know for sure.
92+
///
93+
/// Example 1: isSameAllocation(%0, %1) == true
94+
/// ```
95+
/// %0 = memref.alloc()
96+
/// %1 = memref.subview %0
97+
/// ```
98+
///
99+
/// Example 2: isSameAllocation(%0, %1) == false
100+
/// ```
101+
/// %0 = memref.alloc()
102+
/// %1 = memref.alloc()
103+
/// ```
104+
///
105+
/// Example 3: isSameAllocation(%0, %2) == nullopt
106+
/// ```
107+
/// %0 = memref.alloc()
108+
/// %1 = memref.alloc()
109+
/// %2 = arith.select %c, %0, %1
110+
/// ```
111+
std::optional<bool> isSameAllocation(Value v1, Value v2);
112+
113+
private:
114+
BufferViewFlowAnalysis analysis;
115+
};
116+
81117
} // namespace mlir
82118

83119
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15-
#include "mlir/Analysis/AliasAnalysis.h"
1615
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16+
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
1717
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
1919
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
3434
// Helpers
3535
//===----------------------------------------------------------------------===//
3636

37+
/// Given a memref value, return the "base" value by skipping over all
38+
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39+
static Value getViewBase(Value value) {
40+
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
41+
value = viewLikeOp.getViewSource();
42+
return value;
43+
}
44+
3745
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
3846
ValueRange memrefs,
3947
ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
4957
return success();
5058
}
5159

52-
/// Given a memref value, return the "base" value by skipping over all
53-
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
54-
static Value getViewBase(Value value) {
55-
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
56-
value = viewLikeOp.getViewSource();
57-
return value;
58-
}
59-
6060
/// Return "true" if the given values are guaranteed to be different (and
6161
/// non-aliasing) allocations based on the fact that one value is the result
6262
/// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
8080
/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
8181
/// often a requirement of optimization patterns that there cannot be any
8282
/// aliasing memref in order to perform the desired simplification.
83-
static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
83+
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
8484
ValueRange otherList, Value memref) {
8585
for (auto other : otherList) {
8686
if (distinctAllocAndBlockArgument(other, memref))
8787
continue;
88-
if (!analysis.alias(other, memref).isNo())
88+
std::optional<bool> analysisResult =
89+
analysis.isSameAllocation(other, memref);
90+
if (!analysisResult.has_value() || analysisResult == true)
8991
return true;
9092
}
9193
return false;
@@ -129,8 +131,8 @@ namespace {
129131
struct RemoveDeallocMemrefsContainedInRetained
130132
: public OpRewritePattern<DeallocOp> {
131133
RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
132-
AliasAnalysis &aliasAnalysis)
133-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
134+
BufferOriginAnalysis &analysis)
135+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
134136

135137
/// The passed 'memref' must not have a may-alias relation to any retained
136138
/// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
147149
// deallocated in some situations and can thus not be dropped).
148150
bool atLeastOneMustAlias = false;
149151
for (Value retained : deallocOp.getRetained()) {
150-
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
151-
if (analysisResult.isMay())
152+
std::optional<bool> analysisResult =
153+
analysis.isSameAllocation(retained, memref);
154+
if (!analysisResult.has_value())
152155
return failure();
153-
if (analysisResult.isMust() || analysisResult.isPartial())
156+
if (analysisResult == true)
154157
atLeastOneMustAlias = true;
155158
}
156159
if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
161164
// we can remove that operand later on.
162165
for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
163166
Value updatedCondition = deallocOp.getUpdatedConditions()[i];
164-
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
165-
if (analysisResult.isMust() || analysisResult.isPartial()) {
167+
std::optional<bool> analysisResult =
168+
analysis.isSameAllocation(retained, memref);
169+
if (analysisResult == true) {
166170
auto disjunction = rewriter.create<arith::OrIOp>(
167171
deallocOp.getLoc(), updatedCondition, cond);
168172
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
206210
}
207211

208212
private:
209-
AliasAnalysis &aliasAnalysis;
213+
BufferOriginAnalysis &analysis;
210214
};
211215

212216
/// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
228232
struct RemoveRetainedMemrefsGuaranteedToNotAlias
229233
: public OpRewritePattern<DeallocOp> {
230234
RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
231-
AliasAnalysis &aliasAnalysis)
232-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
235+
BufferOriginAnalysis &analysis)
236+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
233237

234238
LogicalResult matchAndRewrite(DeallocOp deallocOp,
235239
PatternRewriter &rewriter) const override {
236240
SmallVector<Value> newRetainedMemrefs, replacements;
237241

238242
for (auto retainedMemref : deallocOp.getRetained()) {
239-
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
243+
if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
240244
retainedMemref)) {
241245
newRetainedMemrefs.push_back(retainedMemref);
242246
replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
264268
}
265269

266270
private:
267-
AliasAnalysis &aliasAnalysis;
271+
BufferOriginAnalysis &analysis;
268272
};
269273

270274
/// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
297301
struct SplitDeallocWhenNotAliasingAnyOther
298302
: public OpRewritePattern<DeallocOp> {
299303
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
300-
AliasAnalysis &aliasAnalysis)
301-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
304+
BufferOriginAnalysis &analysis)
305+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
302306

303307
LogicalResult matchAndRewrite(DeallocOp deallocOp,
304308
PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
314318
SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
315319
otherMemrefs.erase(otherMemrefs.begin() + i);
316320
// Check if `memref` can split off into a separate bufferization.dealloc.
317-
if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
321+
if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
318322
// `memref` alias with other memrefs, do not split off.
319323
remainingMemrefs.push_back(memref);
320324
remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
352356
}
353357

354358
private:
355-
AliasAnalysis &aliasAnalysis;
359+
BufferOriginAnalysis &analysis;
356360
};
357361

358362
/// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
381385
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
382386
: public OpRewritePattern<DeallocOp> {
383387
RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
384-
AliasAnalysis &aliasAnalysis)
385-
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
388+
BufferOriginAnalysis &analysis)
389+
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
386390

387391
LogicalResult matchAndRewrite(DeallocOp deallocOp,
388392
PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
396400
if (!matchPattern(cond, m_One()))
397401
continue;
398402

399-
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
400-
if (analysisResult.isMust() || analysisResult.isPartial()) {
403+
std::optional<bool> analysisResult =
404+
analysis.isSameAllocation(retained, memref);
405+
if (analysisResult == true) {
401406
rewriter.replaceAllUsesWith(res, cond);
402407
aliasesWithConstTrueMemref[i] = true;
403408
canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
411416
if (!extractOp)
412417
continue;
413418

414-
AliasResult extractAnalysisResult =
415-
aliasAnalysis.alias(retained, extractOp.getOperand());
416-
if (extractAnalysisResult.isMust() ||
417-
extractAnalysisResult.isPartial()) {
419+
std::optional<bool> extractAnalysisResult =
420+
analysis.isSameAllocation(retained, extractOp.getOperand());
421+
if (extractAnalysisResult == true) {
418422
rewriter.replaceAllUsesWith(res, cond);
419423
aliasesWithConstTrueMemref[i] = true;
420424
canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
434438
}
435439

436440
private:
437-
AliasAnalysis &aliasAnalysis;
441+
BufferOriginAnalysis &analysis;
438442
};
439443

440444
} // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
452456
: public bufferization::impl::BufferDeallocationSimplificationBase<
453457
BufferDeallocationSimplificationPass> {
454458
void runOnOperation() override {
455-
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
459+
BufferOriginAnalysis analysis(getOperation());
456460
RewritePatternSet patterns(&getContext());
457461
patterns.add<RemoveDeallocMemrefsContainedInRetained,
458462
RemoveRetainedMemrefsGuaranteedToNotAlias,
459463
SplitDeallocWhenNotAliasingAnyOther,
460464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
461-
aliasAnalysis);
465+
analysis);
462466
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
463467

464468
if (failed(

0 commit comments

Comments
 (0)