Skip to content

[mlir][bufferization] Add BufferOriginAnalysis #86461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define BUFFERIZATION_OPS

include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
///
/// Results in resolve(B) returning {B, C}
ValueSetT resolve(Value value) const;
ValueSetT resolveReverse(Value value) const;

/// Removes the given values from all alias sets.
void remove(const SetVector<Value> &aliasValues);
Expand All @@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {

/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;
ValueMapT reverseDependencies;

/// A set of all SSA values that may be terminal buffers.
DenseSet<Value> terminals;
};

/// An is-same-buffer analysis that checks if two SSA values belong to the same
/// buffer allocation or not.
class BufferOriginAnalysis {
public:
BufferOriginAnalysis(Operation *op);

/// Return "true" if `v1` and `v2` originate from the same buffer allocation.
/// Return "false" if `v1` and `v2` originate from different allocations.
/// Return "nullopt" if we do not know for sure.
///
/// Example 1: isSameAllocation(%0, %1) == true
/// ```
/// %0 = memref.alloc()
/// %1 = memref.subview %0
/// ```
///
/// Example 2: isSameAllocation(%0, %1) == false
/// ```
/// %0 = memref.alloc()
/// %1 = memref.alloc()
/// ```
///
/// Example 3: isSameAllocation(%0, %2) == nullopt
/// ```
/// %0 = memref.alloc()
/// %1 = memref.alloc()
/// %2 = arith.select %c, %0, %1
/// ```
std::optional<bool> isSameAllocation(Value v1, Value v2);

private:
BufferViewFlowAnalysis analysis;
};

} // namespace mlir

#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand All @@ -34,6 +34,14 @@ using namespace mlir::bufferization;
// Helpers
//===----------------------------------------------------------------------===//

/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
value = viewLikeOp.getViewSource();
return value;
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
ValueRange memrefs,
ValueRange conditions,
Expand All @@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
return success();
}

/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
value = viewLikeOp.getViewSource();
return value;
}

/// Return "true" if the given values are guaranteed to be different (and
/// non-aliasing) allocations based on the fact that one value is the result
/// of an allocation and the other value is a block argument of a parent block.
Expand All @@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
/// often a requirement of optimization patterns that there cannot be any
/// aliasing memref in order to perform the desired simplification.
static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
ValueRange otherList, Value memref) {
for (auto other : otherList) {
if (distinctAllocAndBlockArgument(other, memref))
continue;
if (!analysis.alias(other, memref).isNo())
std::optional<bool> analysisResult =
analysis.isSameAllocation(other, memref);
if (!analysisResult.has_value() || analysisResult == true)
return true;
}
return false;
Expand Down Expand Up @@ -129,8 +131,8 @@ namespace {
struct RemoveDeallocMemrefsContainedInRetained
: public OpRewritePattern<DeallocOp> {
RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
AliasAnalysis &aliasAnalysis)
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
BufferOriginAnalysis &analysis)
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}

/// The passed 'memref' must not have a may-alias relation to any retained
/// memref, and at least one must-alias relation. If there is no must-aliasing
Expand All @@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
// deallocated in some situations and can thus not be dropped).
bool atLeastOneMustAlias = false;
for (Value retained : deallocOp.getRetained()) {
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
if (analysisResult.isMay())
std::optional<bool> analysisResult =
analysis.isSameAllocation(retained, memref);
if (!analysisResult.has_value())
return failure();
if (analysisResult.isMust() || analysisResult.isPartial())
if (analysisResult == true)
atLeastOneMustAlias = true;
}
if (!atLeastOneMustAlias)
Expand All @@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
// we can remove that operand later on.
for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
Value updatedCondition = deallocOp.getUpdatedConditions()[i];
AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
if (analysisResult.isMust() || analysisResult.isPartial()) {
std::optional<bool> analysisResult =
analysis.isSameAllocation(retained, memref);
if (analysisResult == true) {
auto disjunction = rewriter.create<arith::OrIOp>(
deallocOp.getLoc(), updatedCondition, cond);
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
Expand Down Expand Up @@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
}

private:
AliasAnalysis &aliasAnalysis;
BufferOriginAnalysis &analysis;
};

/// Remove memrefs from the `retained` list which are guaranteed to not alias
Expand All @@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
struct RemoveRetainedMemrefsGuaranteedToNotAlias
: public OpRewritePattern<DeallocOp> {
RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
AliasAnalysis &aliasAnalysis)
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
BufferOriginAnalysis &analysis)
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}

LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> newRetainedMemrefs, replacements;

for (auto retainedMemref : deallocOp.getRetained()) {
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
retainedMemref)) {
newRetainedMemrefs.push_back(retainedMemref);
replacements.push_back({});
Expand Down Expand Up @@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
}

private:
AliasAnalysis &aliasAnalysis;
BufferOriginAnalysis &analysis;
};

/// Split off memrefs to separate dealloc operations to reduce the number of
Expand Down Expand Up @@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
struct SplitDeallocWhenNotAliasingAnyOther
: public OpRewritePattern<DeallocOp> {
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
AliasAnalysis &aliasAnalysis)
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
BufferOriginAnalysis &analysis)
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}

LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
Expand All @@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
otherMemrefs.erase(otherMemrefs.begin() + i);
// Check if `memref` can split off into a separate bufferization.dealloc.
if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
// `memref` alias with other memrefs, do not split off.
remainingMemrefs.push_back(memref);
remainingConditions.push_back(cond);
Expand Down Expand Up @@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
}

private:
AliasAnalysis &aliasAnalysis;
BufferOriginAnalysis &analysis;
};

/// Check for every retained memref if a must-aliasing memref exists in the
Expand Down Expand Up @@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
: public OpRewritePattern<DeallocOp> {
RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
AliasAnalysis &aliasAnalysis)
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
BufferOriginAnalysis &analysis)
: OpRewritePattern<DeallocOp>(context), analysis(analysis) {}

LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
Expand All @@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!matchPattern(cond, m_One()))
continue;

AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
if (analysisResult.isMust() || analysisResult.isPartial()) {
std::optional<bool> analysisResult =
analysis.isSameAllocation(retained, memref);
if (analysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
Expand All @@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!extractOp)
continue;

AliasResult extractAnalysisResult =
aliasAnalysis.alias(retained, extractOp.getOperand());
if (extractAnalysisResult.isMust() ||
extractAnalysisResult.isPartial()) {
std::optional<bool> extractAnalysisResult =
analysis.isSameAllocation(retained, extractOp.getOperand());
if (extractAnalysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
Expand All @@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
}

private:
AliasAnalysis &aliasAnalysis;
BufferOriginAnalysis &analysis;
};

} // namespace
Expand All @@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
: public bufferization::impl::BufferDeallocationSimplificationBase<
BufferDeallocationSimplificationPass> {
void runOnOperation() override {
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
BufferOriginAnalysis analysis(getOperation());
RewritePatternSet patterns(&getContext());
patterns.add<RemoveDeallocMemrefsContainedInRetained,
RemoveRetainedMemrefsGuaranteedToNotAlias,
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
aliasAnalysis);
analysis);
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());

if (failed(
Expand Down
Loading