Skip to content

[mlir][bufferization] Add BufferOriginAnalysis #86461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 25, 2024

This commit adds the BufferOriginAnalysis, which can be queried to check if two buffer SSA values originate from the same allocation. This new analysis is used in the buffer deallocation pass to fold away or simplify bufferization.dealloc ops more aggressively.

The BufferOriginAnalysis is based on the BufferViewFlowAnalysis, which collects buffer SSA value "same buffer" dependencies. E.g., given IR such as:

%0 = memref.alloc()
%1 = memref.subview %0
%2 = memref.subview %1

The BufferViewFlowAnalysis will report the following "reverse" dependencies (resolveReverse) for %2: {%2, %1, %0}. I.e., all buffer SSA values in the reverse use-def chain that originate from the same allocation as %2. The BufferOriginAnalysis is built on top of that. It handles only simple cases at the moment and may conservatively return "unknown" around certain IR with branches, memref globals and function arguments.

This analysis enables additional simplifications during -buffer-deallocation-simplification. In particular, "regular" scf.for loop nests, that yield buffers (or reallocations thereof) in the same order as they appear in the iter_args, are now handled much more efficiently. Such IR patterns are generated by the sparse compiler.

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/buffer_origin_analysis branch from 83ca1fb to be4524e Compare March 25, 2024 02:58
@matthias-springer matthias-springer marked this pull request as ready for review March 25, 2024 02:59
@llvmbot llvmbot added the mlir label Mar 25, 2024
@llvmbot llvmbot added the mlir:bufferization Bufferization infrastructure label Mar 25, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

This commit adds the BufferOriginAnalysis, which can be queried to check if two buffer SSA values originate from the same allocation. This new analysis is used in the buffer deallocation pass to fold away or simplify bufferization.dealloc ops more aggressively.

The BufferOriginAnalysis is based on the BufferViewFlowAnalysis, which collects buffer SSA value "same buffer" dependencies. E.g., given IR such as:

%0 = memref.alloc()
%1 = memref.subview %0
%2 = memref.subview %1

The BufferViewFlowAnalysis will report the following "reverse" dependencies (resolveReverse) for %2: {%2, %1, %0}. I.e., all buffer SSA values that originate from the same allocation as %2. The BufferOriginAnalysis is built on top of that. It handles only simple cases at the moment and may conservatively return "unknown" around certain IR with branches, memref globals and function arguments.

This analysis enables additional simplifications during -buffer-deallocation-simplification. In particular, "regular" scf.for loop nests, that yield buffers (or reallocations thereof) in the same order as they appear in the iter_args, are now handled much more efficiently. Such IR patterns are generated by the sparse compiler.


Patch is 29.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86461.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h (+36)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+42-38)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+150-10)
  • (added) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir (+86)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir (+6-8)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9dc6afcaab31c8..4f609ddff9a413 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
 #define BUFFERIZATION_OPS
 
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 9e43265c5dfede..4015231c845daf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
   ///
   /// Results in resolve(B) returning {B, C}
   ValueSetT resolve(Value value) const;
+  ValueSetT resolveReverse(Value value) const;
 
   /// Removes the given values from all alias sets.
   void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+  ValueMapT reverseDependencies;
 
   /// A set of all SSA values that may be terminal buffers.
   DenseSet<Value> terminals;
 };
 
+/// An is-same-buffer analysis that checks if two SSA values belong to the same
+/// buffer allocation or not.
+class BufferOriginAnalysis {
+public:
+  BufferOriginAnalysis(Operation *op);
+
+  /// Return "true" if `v1` and `v2` originate from the same buffer allocation.
+  /// Return "false" if `v1` and `v2` originate from different allocations.
+  /// Return "nullopt" if we do not know for sure.
+  ///
+  /// Example 1: isSameAllocation(%0, %1) == true
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.subview %0
+  /// ```
+  ///
+  /// Example 2: isSameAllocation(%0, %1) == false
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// ```
+  ///
+  /// Example 3: isSameAllocation(%0, %2) == nullopt
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// %2 = arith.select %c, %0, %1
+  /// ```
+  std::optional<bool> isSameAllocation(Value v1, Value v2);
+
+private:
+  BufferViewFlowAnalysis analysis;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e30779868b4753..954485cfede3da 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
 // Helpers
 //===----------------------------------------------------------------------===//
 
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+    value = viewLikeOp.getViewSource();
+  return value;
+}
+
 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                             ValueRange memrefs,
                                             ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
-  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
-    value = viewLikeOp.getViewSource();
-  return value;
-}
-
 /// Return "true" if the given values are guaranteed to be different (and
 /// non-aliasing) allocations based on the fact that one value is the result
 /// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
 /// often a requirement of optimization patterns that there cannot be any
 /// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
                                      ValueRange otherList, Value memref) {
   for (auto other : otherList) {
     if (distinctAllocAndBlockArgument(other, memref))
       continue;
-    if (!analysis.alias(other, memref).isNo())
+    std::optional<bool> analysisResult =
+        analysis.isSameAllocation(other, memref);
+    if (!analysisResult.has_value() || analysisResult == true)
       return true;
   }
   return false;
@@ -129,8 +131,8 @@ namespace {
 struct RemoveDeallocMemrefsContainedInRetained
     : public OpRewritePattern<DeallocOp> {
   RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
-                                          AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                          BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   /// The passed 'memref' must not have a may-alias relation to any retained
   /// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
     // deallocated in some situations and can thus not be dropped).
     bool atLeastOneMustAlias = false;
     for (Value retained : deallocOp.getRetained()) {
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMay())
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (!analysisResult.has_value())
         return failure();
-      if (analysisResult.isMust() || analysisResult.isPartial())
+      if (analysisResult == true)
         atLeastOneMustAlias = true;
     }
     if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
     // we can remove that operand later on.
     for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
       Value updatedCondition = deallocOp.getUpdatedConditions()[i];
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMust() || analysisResult.isPartial()) {
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (analysisResult == true) {
         auto disjunction = rewriter.create<arith::OrIOp>(
             deallocOp.getLoc(), updatedCondition, cond);
         rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
 struct RemoveRetainedMemrefsGuaranteedToNotAlias
     : public OpRewritePattern<DeallocOp> {
   RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
-                                            AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                            BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newRetainedMemrefs, replacements;
 
     for (auto retainedMemref : deallocOp.getRetained()) {
-      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+      if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
                                    retainedMemref)) {
         newRetainedMemrefs.push_back(retainedMemref);
         replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
 struct SplitDeallocWhenNotAliasingAnyOther
     : public OpRewritePattern<DeallocOp> {
   SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
-                                      AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                      BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
       SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
       otherMemrefs.erase(otherMemrefs.begin() + i);
       // Check if `memref` can split off into a separate bufferization.dealloc.
-      if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+      if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
         // `memref` alias with other memrefs, do not split off.
         remainingMemrefs.push_back(memref);
         remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
     : public OpRewritePattern<DeallocOp> {
   RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
-                                                AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                                BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!matchPattern(cond, m_One()))
           continue;
 
-        AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-        if (analysisResult.isMust() || analysisResult.isPartial()) {
+        std::optional<bool> analysisResult =
+            analysis.isSameAllocation(retained, memref);
+        if (analysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!extractOp)
           continue;
 
-        AliasResult extractAnalysisResult =
-            aliasAnalysis.alias(retained, extractOp.getOperand());
-        if (extractAnalysisResult.isMust() ||
-            extractAnalysisResult.isPartial()) {
+        std::optional<bool> extractAnalysisResult =
+            analysis.isSameAllocation(retained, extractOp.getOperand());
+        if (extractAnalysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 } // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
     : public bufferization::impl::BufferDeallocationSimplificationBase<
           BufferDeallocationSimplificationPass> {
   void runOnOperation() override {
-    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    BufferOriginAnalysis analysis(getOperation());
     RewritePatternSet patterns(&getContext());
     patterns.add<RemoveDeallocMemrefsContainedInRetained,
                  RemoveRetainedMemrefsGuaranteedToNotAlias,
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
-                                                                aliasAnalysis);
+                                                                analysis);
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
     if (failed(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 9a36057425f366..72f47b8b468ea6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -19,22 +19,23 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+//===----------------------------------------------------------------------===//
+// BufferViewFlowAnalysis
+//===----------------------------------------------------------------------===//
+
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
 
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
-  ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+  BufferViewFlowAnalysis::ValueSetT result;
   SmallVector<Value, 8> queue;
-  queue.push_back(rootValue);
+  queue.push_back(value);
   while (!queue.empty()) {
     Value currentValue = queue.pop_back_val();
     if (result.insert(currentValue).second) {
-      auto it = dependencies.find(currentValue);
-      if (it != dependencies.end()) {
+      auto it = map.find(currentValue);
+      if (it != map.end()) {
         for (Value aliasValue : it->second)
           queue.push_back(aliasValue);
       }
@@ -43,6 +44,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
   return result;
 }
 
+/// Find all immediate and indirect dependent buffers this value could
+/// potentially have. Note that the resulting set will also contain the value
+/// provided as it is a dependent alias of itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+  return resolveValues(dependencies, rootValue);
+}
+
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+  return resolveValues(reverseDependencies, rootValue);
+}
+
 /// Removes the given values from all alias sets.
 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
   for (auto &entry : dependencies)
@@ -69,8 +83,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
       this->dependencies[value].insert(dep);
+      this->reverseDependencies[dep].insert(value);
+    }
   };
 
   // Mark all buffer results and buffer region entry block arguments of the
@@ -188,3 +204,127 @@ bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
   assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
   return terminals.contains(value);
 }
+
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if the given value is the result of a memory allocation.
+static bool hasAllocateSideEffect(Value v) {
+  Operation *op = v.getDefiningOp();
+  if (!op)
+    return false;
+  return hasEffect<MemoryEffects::Allocate>(op, v);
+}
+
+/// Return "true" if the given value is a function block argument.
+static bool isFunctionArgument(Value v) {
+  auto bbArg = dyn_cast<BlockArgument>(v);
+  if (!bbArg)
+    return false;
+  Block *b = bbArg.getOwner();
+  auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
+  if (!funcOp)
+    return false;
+  return bbArg.getOwner() == &funcOp.getFunctionBody().front();
+}
+
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+    value = viewLikeOp.getViewSource();
+  return value;
+}
+
+BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
+
+std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
+  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+  // Skip over all view-like ops.
+  v1 = getViewBase(v1);
+  v2 = getViewBase(v2);
+
+  // Fast path: If both buffers are the same SSA value, we can be sure that
+  // they originate from the same allocation.
+  if (v1 == v2)
+    return true;
+
+  // Compute the SSA values from which the buffers `v1` and `v2` originate.
+  SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+  SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+  // Originating buffers are "terminal" if they could not be traced back any
+  // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+  // - function block arguments
+  // - values defined by allocation ops such as "memref.alloc"
+  // - values defined by ops that are unknown to the buffer view flow analysis
+  // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+  SmallPtrSet<Value, 16> terminal1, terminal2;
+
+  // While gathering terminal buffers, keep track of whether all terminal
+  // buffers are newly allocated buffer or function entry arguments.
+  bool allAllocs1 = true, allAllocs2 = true;
+  bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+  // Helper function that gathers terminal buffers among `origin`.
+  auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
+                                      SmallPtrSet<Value, 16> &terminal,
+                                      bool &allAllocs,
+                                      bool &allAllocsOrFuncEntryArgs) {
+    for (Value v : origin) {
+      if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+        terminal.insert(v);
+        allAllocs &= hasAllocateSideEffect(v);
+        allAllocsOrFuncEntryArgs &=
+            isFunctionArgument(v) || hasAllocateSideEffect(v);
+      }
+    }
+    assert(!terminal.empty() && "expected non-empty terminal set");
+  };
+
+  // Gather terminal buffers for `v1` and `v2`.
+  gatherTerminalBuffers(origin1, terminal1, allAllocs1,
+      ...
[truncated]

Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks good and useful.

I wonder whether this could be refactored to be a more generally useful utility: tracing back through sets of values to terminal nodes is something I could reuse in other places.

Having a single impl that is configurable with the proper conditions and op/interfaces would be useful.

@nicolasvasilache nicolasvasilache self-requested a review March 25, 2024 07:27
@matthias-springer
Copy link
Member Author

I wonder whether this could be refactored to be a more generally useful utility: tracing back through sets of values to terminal nodes is something I could reuse in other places.

I am working on something similar. In order to support more cases, I need to do multiple graph traversals over the IR (variants of DFS, checking reachability, etc.) and what you suggested will probably be part of that. (Still working on the design...)

@matthias-springer matthias-springer merged commit dbfc38e into main Mar 25, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/buffer_origin_analysis branch March 25, 2024 09:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants