Skip to content

[mlir][Interfaces][NFC] Move region loop detection to RegionBranchOpInterface #77090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jan 5, 2024

BufferPlacementTransformationBase::isLoop checks if there a loop in the region branching graph of an operation. This algorithm is similar to isRegionReachable in the RegionBranchOpInterface. To avoid duplicate code, isRegionReachable is generalized, so that it can be used to detect region loops. A helper function RegionBranchOpInterface::hasLoop is added.

This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.

Also move the isLoop to BufferOptimizations.cpp, so that we can gradually retire BufferPlacementTransformationBase. (This is so that proper error handling can be added to BufferViewFlowAnalysis.)

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Jan 5, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

BufferPlacementTransformationBase::isLoop checks if there a loop in the region branching graph of an operation. This algorithm is similar to isRegionReachable in the RegionBranchOpInterface. To avoid duplicate code, isRegionReachable is generalized, so that it can be used to detect region loops. A helper function RegionBranchOpInterface::hasLoop is added.

This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.


Full diff: https://github.com/llvm/llvm-project/pull/77090.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (+3-31)
  • (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+43-6)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     /// eventually branch back to the same region. (Maybe after passing through
     /// other regions.)
     bool isRepetitiveRegion(unsigned index);
+
+    /// Return `true` if there is a loop in the region branching graph. Only
+    /// reachable regions (starting from the entry regions) are considered.
+    bool hasLoop();
   }];
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..227a3df8fb9974 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -108,39 +108,11 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
 
   // If the operation does not implement the `RegionBranchOpInterface`, it is
   // (currently) not possible to detect a loop.
-  RegionBranchOpInterface regionInterface;
-  if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
+  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+  if (!regionInterface)
     return false;
 
-  // Recurses into a region using the current region interface to find potential
-  // cycles.
-  SmallPtrSet<Region *, 4> visitedRegions;
-  std::function<bool(Region *)> recurse = [&](Region *current) {
-    if (!current)
-      return false;
-    // If we have found a back edge, the parent operation induces a loop.
-    if (!visitedRegions.insert(current).second)
-      return true;
-    // Recurses into all region successors.
-    SmallVector<RegionSuccessor, 2> successors;
-    regionInterface.getSuccessorRegions(current, successors);
-    for (RegionSuccessor &regionEntry : successors)
-      if (recurse(regionEntry.getSuccessor()))
-        return true;
-    return false;
-  };
-
-  // Start with all entry regions and test whether they induce a loop.
-  SmallVector<RegionSuccessor, 2> successorRegions;
-  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                      successorRegions);
-  for (RegionSuccessor &regionEntry : successorRegions) {
-    if (recurse(regionEntry.getSuccessor()))
-      return true;
-    visitedRegions.clear();
-  }
-
-  return false;
+  return regionInterface.hasLoop();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..a1ea22dbfc6937 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,21 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   return success();
 }
 
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
-  assert(begin->getParentOp() == r->getParentOp() &&
-         "expected that both regions belong to the same op");
+namespace {
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter indicates all already visited
+/// regions.
+using StopConditionFn =
+    std::function<bool(Region *, const SmallVector<bool> &visited)>;
+} // namespace
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+                                StopConditionFn stopConditionFn) {
   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
   SmallVector<bool> visited(op->getNumRegions(), false);
   visited[begin->getRegionNumber()] = true;
@@ -242,7 +252,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
   // Process all regions in the worklist via DFS.
   while (!worklist.empty()) {
     Region *nextRegion = worklist.pop_back_val();
-    if (nextRegion == r)
+    if (stopConditionFn(nextRegion, visited))
       return true;
     if (visited[nextRegion->getRegionNumber()])
       continue;
@@ -253,6 +263,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
   return false;
 }
 
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+  assert(begin->getParentOp() == r->getParentOp() &&
+         "expected that both regions belong to the same op");
+  return traverseRegionGraph(
+      begin, [&](Region *nextRegion, const SmallVector<bool> &visited) {
+        // Interrupt traversal if `r` was reached.
+        return nextRegion == r;
+      });
+}
+
 /// Return `true` if `a` and `b` are in mutually exclusive regions.
 ///
 /// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +328,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
   return isRegionReachable(region, region);
 }
 
+bool RegionBranchOpInterface::hasLoop() {
+  SmallVector<RegionSuccessor> entryRegions;
+  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+  for (RegionSuccessor successor : entryRegions)
+    if (!successor.isParent() &&
+        traverseRegionGraph(
+            successor.getSuccessor(),
+            [](Region *nextRegion, const SmallVector<bool> &visited) {
+              // Interrupt traversal if the region was already visited.
+              return visited[nextRegion->getRegionNumber()];
+            }))
+      return true;
+  return false;
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   while (Region *region = op->getParentRegion()) {
     op = region->getParentOp();

@matthias-springer matthias-springer force-pushed the region_branch_op_interface_has_loop branch from 805267f to 33c35ad Compare January 5, 2024 14:11
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just two very minor comments

@matthias-springer matthias-springer force-pushed the region_branch_op_interface_has_loop branch from 33c35ad to 30f4424 Compare January 7, 2024 12:41
…Interface`

`BufferPlacementTransformationBase::isLoop` checks if there a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.

This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
@matthias-springer matthias-springer force-pushed the region_branch_op_interface_has_loop branch from 30f4424 to 20ea8f3 Compare January 7, 2024 12:44
@matthias-springer matthias-springer merged commit dd450f0 into llvm:main Jan 7, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…Interface` (llvm#77090)

`BufferPlacementTransformationBase::isLoop` checks if there a loop in
the region branching graph of an operation. This algorithm is similar to
`isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate
code, `isRegionReachable` is generalized, so that it can be used to
detect region loops. A helper function
`RegionBranchOpInterface::hasLoop` is added.

This change also turns a recursive implementation into an iterative one,
which is the preferred implementation strategy in LLVM.

Also move the `isLoop` to `BufferOptimizations.cpp`, so that we can
gradually retire `BufferPlacementTransformationBase`. (This is so that
proper error handling can be added to `BufferViewFlowAnalysis`.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants