[mlir][bufferization] Add BufferOriginAnalysis
#86461
✅ With the latest revision this PR passed the Python code formatter.
✅ With the latest revision this PR passed the C/C++ code formatter.
83ca1fb to be4524e
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization
Author: Matthias Springer (matthias-springer)
Patch is 29.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86461.diff
6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9dc6afcaab31c8..4f609ddff9a413 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
#define BUFFERIZATION_OPS
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 9e43265c5dfede..4015231c845daf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
///
/// Results in resolve(B) returning {B, C}
ValueSetT resolve(Value value) const;
+ ValueSetT resolveReverse(Value value) const;
/// Removes the given values from all alias sets.
void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;
+ ValueMapT reverseDependencies;
/// A set of all SSA values that may be terminal buffers.
DenseSet<Value> terminals;
};
+/// An is-same-buffer analysis that checks if two SSA values belong to the same
+/// buffer allocation or not.
+class BufferOriginAnalysis {
+public:
+ BufferOriginAnalysis(Operation *op);
+
+ /// Return "true" if `v1` and `v2` originate from the same buffer allocation.
+ /// Return "false" if `v1` and `v2` originate from different allocations.
+ /// Return "nullopt" if we do not know for sure.
+ ///
+ /// Example 1: isSameAllocation(%0, %1) == true
+ /// ```
+ /// %0 = memref.alloc()
+ /// %1 = memref.subview %0
+ /// ```
+ ///
+ /// Example 2: isSameAllocation(%0, %1) == false
+ /// ```
+ /// %0 = memref.alloc()
+ /// %1 = memref.alloc()
+ /// ```
+ ///
+ /// Example 3: isSameAllocation(%0, %2) == nullopt
+ /// ```
+ /// %0 = memref.alloc()
+ /// %1 = memref.alloc()
+ /// %2 = arith.select %c, %0, %1
+ /// ```
+ std::optional<bool> isSameAllocation(Value v1, Value v2);
+
+private:
+ BufferViewFlowAnalysis analysis;
+};
+
} // namespace mlir
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e30779868b4753..954485cfede3da 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
// Helpers
//===----------------------------------------------------------------------===//
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ value = viewLikeOp.getViewSource();
+ return value;
+}
+
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
ValueRange memrefs,
ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
return success();
}
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
- value = viewLikeOp.getViewSource();
- return value;
-}
-
/// Return "true" if the given values are guaranteed to be different (and
/// non-aliasing) allocations based on the fact that one value is the result
/// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
/// often a requirement of optimization patterns that there cannot be any
/// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
ValueRange otherList, Value memref) {
for (auto other : otherList) {
if (distinctAllocAndBlockArgument(other, memref))
continue;
- if (!analysis.alias(other, memref).isNo())
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(other, memref);
+ if (!analysisResult.has_value() || analysisResult == true)
return true;
}
return false;
@@ -129,8 +131,8 @@ namespace {
struct RemoveDeallocMemrefsContainedInRetained
: public OpRewritePattern<DeallocOp> {
RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
/// The passed 'memref' must not have a may-alias relation to any retained
/// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
// deallocated in some situations and can thus not be dropped).
bool atLeastOneMustAlias = false;
for (Value retained : deallocOp.getRetained()) {
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMay())
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (!analysisResult.has_value())
return failure();
- if (analysisResult.isMust() || analysisResult.isPartial())
+ if (analysisResult == true)
atLeastOneMustAlias = true;
}
if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
// we can remove that operand later on.
for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
Value updatedCondition = deallocOp.getUpdatedConditions()[i];
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMust() || analysisResult.isPartial()) {
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (analysisResult == true) {
auto disjunction = rewriter.create<arith::OrIOp>(
deallocOp.getLoc(), updatedCondition, cond);
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
struct RemoveRetainedMemrefsGuaranteedToNotAlias
: public OpRewritePattern<DeallocOp> {
RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> newRetainedMemrefs, replacements;
for (auto retainedMemref : deallocOp.getRetained()) {
- if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+ if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
retainedMemref)) {
newRetainedMemrefs.push_back(retainedMemref);
replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
struct SplitDeallocWhenNotAliasingAnyOther
: public OpRewritePattern<DeallocOp> {
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
otherMemrefs.erase(otherMemrefs.begin() + i);
// Check if `memref` can split off into a separate bufferization.dealloc.
- if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+ if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
// `memref` alias with other memrefs, do not split off.
remainingMemrefs.push_back(memref);
remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
: public OpRewritePattern<DeallocOp> {
RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!matchPattern(cond, m_One()))
continue;
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMust() || analysisResult.isPartial()) {
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (analysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!extractOp)
continue;
- AliasResult extractAnalysisResult =
- aliasAnalysis.alias(retained, extractOp.getOperand());
- if (extractAnalysisResult.isMust() ||
- extractAnalysisResult.isPartial()) {
+ std::optional<bool> extractAnalysisResult =
+ analysis.isSameAllocation(retained, extractOp.getOperand());
+ if (extractAnalysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
} // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
: public bufferization::impl::BufferDeallocationSimplificationBase<
BufferDeallocationSimplificationPass> {
void runOnOperation() override {
- AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+ BufferOriginAnalysis analysis(getOperation());
RewritePatternSet patterns(&getContext());
patterns.add<RemoveDeallocMemrefsContainedInRetained,
RemoveRetainedMemrefsGuaranteedToNotAlias,
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
- aliasAnalysis);
+ analysis);
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
if (failed(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 9a36057425f366..72f47b8b468ea6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -19,22 +19,23 @@
using namespace mlir;
using namespace mlir::bufferization;
+//===----------------------------------------------------------------------===//
+// BufferViewFlowAnalysis
+//===----------------------------------------------------------------------===//
+
/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
- ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+ BufferViewFlowAnalysis::ValueSetT result;
SmallVector<Value, 8> queue;
- queue.push_back(rootValue);
+ queue.push_back(value);
while (!queue.empty()) {
Value currentValue = queue.pop_back_val();
if (result.insert(currentValue).second) {
- auto it = dependencies.find(currentValue);
- if (it != dependencies.end()) {
+ auto it = map.find(currentValue);
+ if (it != map.end()) {
for (Value aliasValue : it->second)
queue.push_back(aliasValue);
}
@@ -43,6 +44,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
return result;
}
+/// Find all immediate and indirect dependent buffers this value could
+/// potentially have. Note that the resulting set will also contain the value
+/// provided as it is a dependent alias of itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+ return resolveValues(dependencies, rootValue);
+}
+
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+ return resolveValues(reverseDependencies, rootValue);
+}
+
/// Removes the given values from all alias sets.
void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
for (auto &entry : dependencies)
@@ -69,8 +83,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
- for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+ for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
this->dependencies[value].insert(dep);
+ this->reverseDependencies[dep].insert(value);
+ }
};
// Mark all buffer results and buffer region entry block arguments of the
@@ -188,3 +204,127 @@ bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
return terminals.contains(value);
}
+
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if the given value is the result of a memory allocation.
+static bool hasAllocateSideEffect(Value v) {
+ Operation *op = v.getDefiningOp();
+ if (!op)
+ return false;
+ return hasEffect<MemoryEffects::Allocate>(op, v);
+}
+
+/// Return "true" if the given value is a function block argument.
+static bool isFunctionArgument(Value v) {
+ auto bbArg = dyn_cast<BlockArgument>(v);
+ if (!bbArg)
+ return false;
+ Block *b = bbArg.getOwner();
+ auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
+ if (!funcOp)
+ return false;
+ return bbArg.getOwner() == &funcOp.getFunctionBody().front();
+}
+
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ value = viewLikeOp.getViewSource();
+ return value;
+}
+
+BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
+
+std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
+ assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+ assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+ // Skip over all view-like ops.
+ v1 = getViewBase(v1);
+ v2 = getViewBase(v2);
+
+ // Fast path: If both buffers are the same SSA value, we can be sure that
+ // they originate from the same allocation.
+ if (v1 == v2)
+ return true;
+
+ // Compute the SSA values from which the buffers `v1` and `v2` originate.
+ SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+ SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+ // Originating buffers are "terminal" if they could not be traced back any
+ // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+ // - function block arguments
+ // - values defined by allocation ops such as "memref.alloc"
+ // - values defined by ops that are unknown to the buffer view flow analysis
+ // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+ SmallPtrSet<Value, 16> terminal1, terminal2;
+
+ // While gathering terminal buffers, keep track of whether all terminal
+ // buffers are newly allocated buffer or function entry arguments.
+ bool allAllocs1 = true, allAllocs2 = true;
+ bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+ // Helper function that gathers terminal buffers among `origin`.
+ auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
+ SmallPtrSet<Value, 16> &terminal,
+ bool &allAllocs,
+ bool &allAllocsOrFuncEntryArgs) {
+ for (Value v : origin) {
+ if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+ terminal.insert(v);
+ allAllocs &= hasAllocateSideEffect(v);
+ allAllocsOrFuncEntryArgs &=
+ isFunctionArgument(v) || hasAllocateSideEffect(v);
+ }
+ }
+ assert(!terminal.empty() && "expected non-empty terminal set");
+ };
+
+ // Gather terminal buffers for `v1` and `v2`.
+ gatherTerminalBuffers(origin1, terminal1, allAllocs1,
+ ...
[truncated]
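The call sites in the diff consume the `std::optional<bool>` result of `isSameAllocation` in two different ways: `potentiallyAliasesMemref` treats "unknown" like "same allocation", while the rewrite patterns act only on a definite `true`. The following standalone sketch illustrates that distinction without depending on MLIR. The alias `OriginResult` and the three helper functions are hypothetical names introduced for illustration; they are not part of the patch.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for the result type of isSameAllocation():
// true = same allocation, false = different allocations, nullopt = unknown.
using OriginResult = std::optional<bool>;

// Conservative query: only a definite "false" rules out aliasing.
bool mayBeSameAllocation(OriginResult r) { return !r.has_value() || *r; }

// Strict query: only a definite "true" counts. Comparing an empty optional
// with `true` yields false, so "unknown" is rejected here.
bool isDefinitelySameAllocation(OriginResult r) { return r == true; }

// Same shape as potentiallyAliasesMemref in the diff: any result that is
// not a definite "false" makes the whole list potentially aliasing.
bool anyMayAlias(const std::vector<OriginResult> &results) {
  for (OriginResult r : results)
    if (mayBeSameAllocation(r))
      return true;
  return false;
}
```

Note that writing `if (r)` on a `std::optional<bool>` tests whether a value is present, not whether it is `true`, which is presumably why the patch compares against `true` explicitly.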
Thanks, looks good and useful.
I wonder whether this could be refactored to be a more generally useful utility: tracing back through sets of values to terminal nodes is something I could reuse in other places.
Having a single impl that is configurable with the proper conditions and op/interfaces would be useful.
I am working on something similar. In order to support more cases, I need to do multiple graph traversals over the IR (variants of DFS, checking reachability, etc.) and what you suggested will probably be part of that. (Still working on the design...)
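One possible shape for such a configurable utility is a worklist traversal parameterized by a predecessor function and a terminal predicate. This is only a sketch under stated assumptions: it uses plain integers instead of MLIR values, `collectTerminals` and `exampleTerminals` are hypothetical names, and it is not the design mentioned in the reply above.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Walk backwards from `start` using `predecessors` and collect every
// visited node for which `isTerminal` returns true. Each node is visited
// at most once, so cycles in the graph are handled.
template <typename Node, typename PredFn, typename TermFn>
std::set<Node> collectTerminals(Node start, PredFn predecessors,
                                TermFn isTerminal) {
  std::set<Node> visited, terminals;
  std::vector<Node> worklist{start};
  while (!worklist.empty()) {
    Node current = worklist.back();
    worklist.pop_back();
    if (!visited.insert(current).second)
      continue; // Already processed.
    if (isTerminal(current))
      terminals.insert(current);
    for (const Node &pred : predecessors(current))
      worklist.push_back(pred);
  }
  return terminals;
}

// Illustrative graph (an assumption, not taken from the patch):
// 0 and 3 play the role of allocations, 1 is derived from 0, 2 is derived
// from 1, and 4 is derived from either 1 or 3 (like a select).
std::set<int> exampleTerminals(int start) {
  const std::map<int, std::vector<int>> preds = {
      {1, {0}}, {2, {1}}, {4, {1, 3}}};
  auto predecessors = [&](int n) {
    auto it = preds.find(n);
    return it == preds.end() ? std::vector<int>{} : it->second;
  };
  // A node without recorded predecessors cannot be traced back further.
  auto isTerminal = [&](int n) { return preds.find(n) == preds.end(); };
  return collectTerminals(start, predecessors, isTerminal);
}
```

In this toy graph, `exampleTerminals(2)` yields a single terminal, whereas `exampleTerminals(4)` yields two, which corresponds to the kind of situation where an origin analysis cannot give a definite answer.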
This commit adds the `BufferOriginAnalysis`, which can be queried to check if two buffer SSA values originate from the same allocation. This new analysis is used in the buffer deallocation pass to fold away or simplify `bufferization.dealloc` ops more aggressively.

The `BufferOriginAnalysis` is based on the `BufferViewFlowAnalysis`, which collects buffer SSA value "same buffer" dependencies. E.g., given IR such as:

The `BufferViewFlowAnalysis` will report the following "reverse" dependencies (`resolveReverse`) for `%2`: {`%2`, `%1`, `%0`}. I.e., all buffer SSA values in the reverse use-def chain that originate from the same allocation as `%2`. The `BufferOriginAnalysis` is built on top of that. It handles only simple cases at the moment and may conservatively return "unknown" around certain IR with branches, memref globals and function arguments.

This analysis enables additional simplifications during `-buffer-deallocation-simplification`. In particular, "regular" scf.for loop nests, that yield buffers (or reallocations thereof) in the same order as they appear in the iter_args, are now handled much more efficiently. Such IR patterns are generated by the sparse compiler.
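The forward and reverse dependency maps described above can be illustrated with a small standalone sketch. It assumes a chain in which value 1 is a view of value 0 and value 2 is a view of value 1, and it uses integers in place of SSA values. `ViewFlowSketch`, `ValueId`, and `makeSubviewChain` are hypothetical names; the traversal follows the worklist loop of `resolveValues` in the diff, but this is not the MLIR implementation.

```cpp
#include <cassert>
#include <set>
#include <unordered_map>
#include <vector>

using ValueId = int; // Stand-in for an SSA value.
using DependencyMap = std::unordered_map<ValueId, std::set<ValueId>>;

struct ViewFlowSketch {
  DependencyMap dependencies;        // value -> values derived from it
  DependencyMap reverseDependencies; // value -> values it is derived from

  // Record that `dep` may be the same buffer as `value`, in both directions.
  void registerDependency(ValueId value, ValueId dep) {
    dependencies[value].insert(dep);
    reverseDependencies[dep].insert(value);
  }

  // Transitive closure over `map`, including `start` itself.
  static std::set<ValueId> resolveValues(const DependencyMap &map,
                                         ValueId start) {
    std::set<ValueId> result;
    std::vector<ValueId> queue{start};
    while (!queue.empty()) {
      ValueId current = queue.back();
      queue.pop_back();
      if (!result.insert(current).second)
        continue; // Already visited.
      auto it = map.find(current);
      if (it != map.end())
        queue.insert(queue.end(), it->second.begin(), it->second.end());
    }
    return result;
  }

  std::set<ValueId> resolve(ValueId v) const {
    return resolveValues(dependencies, v);
  }
  std::set<ValueId> resolveReverse(ValueId v) const {
    return resolveValues(reverseDependencies, v);
  }
};

// Assumed example: 0 is an allocation, 1 is a view of 0, 2 is a view of 1.
ViewFlowSketch makeSubviewChain() {
  ViewFlowSketch flow;
  flow.registerDependency(0, 1);
  flow.registerDependency(1, 2);
  return flow;
}
```

With this chain, `resolveReverse(2)` contains all three values, matching the {`%2`, `%1`, `%0`} set in the description, while `resolveReverse(0)` contains only the allocation itself.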