-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][Interfaces][NFC] Move region loop detection to RegionBranchOpInterface
#77090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Interfaces][NFC] Move region loop detection to RegionBranchOpInterface
#77090
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) Changes
This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM. Full diff: https://github.com/llvm/llvm-project/pull/77090.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
/// eventually branch back to the same region. (Maybe after passing through
/// other regions.)
bool isRepetitiveRegion(unsigned index);
+
+ /// Return `true` if there is a loop in the region branching graph. Only
+ /// reachable regions (starting from the entry regions) are considered.
+ bool hasLoop();
}];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..227a3df8fb9974 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -108,39 +108,11 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
// If the operation does not implement the `RegionBranchOpInterface`, it is
// (currently) not possible to detect a loop.
- RegionBranchOpInterface regionInterface;
- if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
+ auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+ if (!regionInterface)
return false;
- // Recurses into a region using the current region interface to find potential
- // cycles.
- SmallPtrSet<Region *, 4> visitedRegions;
- std::function<bool(Region *)> recurse = [&](Region *current) {
- if (!current)
- return false;
- // If we have found a back edge, the parent operation induces a loop.
- if (!visitedRegions.insert(current).second)
- return true;
- // Recurses into all region successors.
- SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(current, successors);
- for (RegionSuccessor ®ionEntry : successors)
- if (recurse(regionEntry.getSuccessor()))
- return true;
- return false;
- };
-
- // Start with all entry regions and test whether they induce a loop.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
- successorRegions);
- for (RegionSuccessor ®ionEntry : successorRegions) {
- if (recurse(regionEntry.getSuccessor()))
- return true;
- visitedRegions.clear();
- }
-
- return false;
+ return regionInterface.hasLoop();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..a1ea22dbfc6937 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,21 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return success();
}
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
- assert(begin->getParentOp() == r->getParentOp() &&
- "expected that both regions belong to the same op");
+namespace {
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter indicates all already visited
+/// regions.
+using StopConditionFn =
+ std::function<bool(Region *, const SmallVector<bool> &visited)>;
+} // namespace
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+ StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
@@ -242,7 +252,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- if (nextRegion == r)
+ if (stopConditionFn(nextRegion, visited))
return true;
if (visited[nextRegion->getRegionNumber()])
continue;
@@ -253,6 +263,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
return false;
}
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+ assert(begin->getParentOp() == r->getParentOp() &&
+ "expected that both regions belong to the same op");
+ return traverseRegionGraph(
+ begin, [&](Region *nextRegion, const SmallVector<bool> &visited) {
+ // Interrupt traversal if `r` was reached.
+ return nextRegion == r;
+ });
+}
+
/// Return `true` if `a` and `b` are in mutually exclusive regions.
///
/// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +328,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
return isRegionReachable(region, region);
}
+bool RegionBranchOpInterface::hasLoop() {
+ SmallVector<RegionSuccessor> entryRegions;
+ getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+ for (RegionSuccessor successor : entryRegions)
+ if (!successor.isParent() &&
+ traverseRegionGraph(
+ successor.getSuccessor(),
+ [](Region *nextRegion, const SmallVector<bool> &visited) {
+ // Interrupt traversal if the region was already visited.
+ return visited[nextRegion->getRegionNumber()];
+ }))
+ return true;
+ return false;
+}
+
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
while (Region *region = op->getParentRegion()) {
op = region->getParentOp();
|
805267f
to
33c35ad
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Just two very minor comments
33c35ad
to
30f4424
Compare
…Interface` `BufferPlacementTransformationBase::isLoop` checks if there a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added. This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
30f4424
to
20ea8f3
Compare
…Interface` (llvm#77090) `BufferPlacementTransformationBase::isLoop` checks if there a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added. This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM. Also move the `isLoop` to `BufferOptimizations.cpp`, so that we can gradually retire `BufferPlacementTransformationBase`. (This is so that proper error handling can be added to `BufferViewFlowAnalysis`.)
BufferPlacementTransformationBase::isLoop
checks if there a loop in the region branching graph of an operation. This algorithm is similar toisRegionReachable
in theRegionBranchOpInterface
. To avoid duplicate code,isRegionReachable
is generalized, so that it can be used to detect region loops. A helper functionRegionBranchOpInterface::hasLoop
is added.This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
Also move the
isLoop
toBufferOptimizations.cpp
, so that we can gradually retireBufferPlacementTransformationBase
. (This is so that proper error handling can be added toBufferViewFlowAnalysis
.)