Skip to content

Commit dd450f0

Browse files
[mlir][Interfaces][NFC] Move region loop detection to RegionBranchOpInterface (#77090)
`BufferPlacementTransformationBase::isLoop` checks if there a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added. This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM. Also move the `isLoop` to `BufferOptimizations.cpp`, so that we can gradually retire `BufferPlacementTransformationBase`. (This is so that proper error handling can be added to `BufferViewFlowAnalysis`.)
1 parent 2eb7a82 commit dd450f0

File tree

5 files changed

+66
-65
lines changed

5 files changed

+66
-65
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,6 @@ class BufferPlacementTransformationBase {
100100
return dom;
101101
}
102102

103-
/// Returns true if the given operation represents a loop by testing whether
104-
/// it implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`.
105-
/// In the case of a `RegionBranchOpInterface`, it checks all region-based
106-
/// control-flow edges for cycles.
107-
static bool isLoop(Operation *op);
108-
109103
/// Constructs a new operation base using the given root operation.
110104
BufferPlacementTransformationBase(Operation *op);
111105

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
272272
/// eventually branch back to the same region. (Maybe after passing through
273273
/// other regions.)
274274
bool isRepetitiveRegion(unsigned index);
275+
276+
/// Return `true` if there is a loop in the region branching graph. Only
277+
/// reachable regions (starting from the entry regions) are considered.
278+
bool hasLoop();
275279
}];
276280
}
277281

mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ static bool isKnownControlFlowInterface(Operation *op) {
4040
return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
4141
}
4242

43+
/// Returns true if the given operation represents a loop by testing whether it
44+
/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
45+
/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
46+
/// flow edges for cycles.
47+
static bool isLoop(Operation *op) {
48+
// If the operation implements the `LoopLikeOpInterface` it can be considered
49+
// a loop.
50+
if (isa<LoopLikeOpInterface>(op))
51+
return true;
52+
53+
// If the operation does not implement the `RegionBranchOpInterface`, it is
54+
// (currently) not possible to detect a loop.
55+
auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
56+
if (!regionInterface)
57+
return false;
58+
59+
return regionInterface.hasLoop();
60+
}
61+
4362
/// Returns true if the given operation implements the AllocationOpInterface
4463
/// and it supports the dominate block hoisting.
4564
static bool allowAllocDominateBlockHoisting(Operation *op) {
@@ -115,8 +134,7 @@ static bool hasAllocationScope(Value alloc,
115134
// Check if the operation is a known control flow interface and break the
116135
// loop to avoid transformation in loops. Furthermore skip transformation
117136
// if the operation does not implement a RegionBeanchOpInterface.
118-
if (BufferPlacementTransformationBase::isLoop(parentOp) ||
119-
!isKnownControlFlowInterface(parentOp))
137+
if (isLoop(parentOp) || !isKnownControlFlowInterface(parentOp))
120138
break;
121139
}
122140
} while ((region = region->getParentRegion()));
@@ -290,9 +308,7 @@ struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
290308
}
291309

292310
/// Returns true if the given operation does not represent a loop.
293-
bool isLegalPlacement(Operation *op) {
294-
return !BufferPlacementTransformationBase::isLoop(op);
295-
}
311+
bool isLegalPlacement(Operation *op) { return !isLoop(op); }
296312

297313
/// Returns true if the given operation should be considered for hoisting.
298314
static bool shouldHoistOpType(Operation *op) {
@@ -327,7 +343,7 @@ struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
327343
/// given loop operation. If this is the case, it indicates that the
328344
/// allocation is passed via a back edge.
329345
bool isLegalPlacement(Operation *op) {
330-
return BufferPlacementTransformationBase::isLoop(op) &&
346+
return isLoop(op) &&
331347
!dominators->dominates(aliasDominatorBlock, op->getBlock());
332348
}
333349

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -96,53 +96,6 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
9696
Operation *op)
9797
: aliases(op), allocs(op), liveness(op) {}
9898

99-
/// Returns true if the given operation represents a loop by testing whether it
100-
/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
101-
/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
102-
/// flow edges for cycles.
103-
bool BufferPlacementTransformationBase::isLoop(Operation *op) {
104-
// If the operation implements the `LoopLikeOpInterface` it can be considered
105-
// a loop.
106-
if (isa<LoopLikeOpInterface>(op))
107-
return true;
108-
109-
// If the operation does not implement the `RegionBranchOpInterface`, it is
110-
// (currently) not possible to detect a loop.
111-
RegionBranchOpInterface regionInterface;
112-
if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
113-
return false;
114-
115-
// Recurses into a region using the current region interface to find potential
116-
// cycles.
117-
SmallPtrSet<Region *, 4> visitedRegions;
118-
std::function<bool(Region *)> recurse = [&](Region *current) {
119-
if (!current)
120-
return false;
121-
// If we have found a back edge, the parent operation induces a loop.
122-
if (!visitedRegions.insert(current).second)
123-
return true;
124-
// Recurses into all region successors.
125-
SmallVector<RegionSuccessor, 2> successors;
126-
regionInterface.getSuccessorRegions(current, successors);
127-
for (RegionSuccessor &regionEntry : successors)
128-
if (recurse(regionEntry.getSuccessor()))
129-
return true;
130-
return false;
131-
};
132-
133-
// Start with all entry regions and test whether they induce a loop.
134-
SmallVector<RegionSuccessor, 2> successorRegions;
135-
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
136-
successorRegions);
137-
for (RegionSuccessor &regionEntry : successorRegions) {
138-
if (recurse(regionEntry.getSuccessor()))
139-
return true;
140-
visitedRegions.clear();
141-
}
142-
143-
return false;
144-
}
145-
14699
//===----------------------------------------------------------------------===//
147100
// BufferPlacementTransformationBase
148101
//===----------------------------------------------------------------------===//

mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,18 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
219219
return success();
220220
}
221221

222-
/// Return `true` if region `r` is reachable from region `begin` according to
223-
/// the RegionBranchOpInterface (by taking a branch).
224-
static bool isRegionReachable(Region *begin, Region *r) {
225-
assert(begin->getParentOp() == r->getParentOp() &&
226-
"expected that both regions belong to the same op");
222+
/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223+
/// this function returns "true" for a successor region. The first parameter is
224+
/// the successor region. The second parameter indicates all already visited
225+
/// regions.
226+
using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
227+
228+
/// Traverse the region graph starting at `begin`. The traversal is interrupted
229+
/// if `stopCondition` evaluates to "true" for a successor region. In that case,
230+
/// this function returns "true". Otherwise, if the traversal was not
231+
/// interrupted, this function returns "false".
232+
static bool traverseRegionGraph(Region *begin,
233+
StopConditionFn stopConditionFn) {
227234
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
228235
SmallVector<bool> visited(op->getNumRegions(), false);
229236
visited[begin->getRegionNumber()] = true;
@@ -242,7 +249,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
242249
// Process all regions in the worklist via DFS.
243250
while (!worklist.empty()) {
244251
Region *nextRegion = worklist.pop_back_val();
245-
if (nextRegion == r)
252+
if (stopConditionFn(nextRegion, visited))
246253
return true;
247254
if (visited[nextRegion->getRegionNumber()])
248255
continue;
@@ -253,6 +260,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
253260
return false;
254261
}
255262

263+
/// Return `true` if region `r` is reachable from region `begin` according to
264+
/// the RegionBranchOpInterface (by taking a branch).
265+
static bool isRegionReachable(Region *begin, Region *r) {
266+
assert(begin->getParentOp() == r->getParentOp() &&
267+
"expected that both regions belong to the same op");
268+
return traverseRegionGraph(begin,
269+
[&](Region *nextRegion, ArrayRef<bool> visited) {
270+
// Interrupt traversal if `r` was reached.
271+
return nextRegion == r;
272+
});
273+
}
274+
256275
/// Return `true` if `a` and `b` are in mutually exclusive regions.
257276
///
258277
/// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +325,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
306325
return isRegionReachable(region, region);
307326
}
308327

328+
bool RegionBranchOpInterface::hasLoop() {
329+
SmallVector<RegionSuccessor> entryRegions;
330+
getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
331+
for (RegionSuccessor successor : entryRegions)
332+
if (!successor.isParent() &&
333+
traverseRegionGraph(successor.getSuccessor(),
334+
[](Region *nextRegion, ArrayRef<bool> visited) {
335+
// Interrupt traversal if the region was already
336+
// visited.
337+
return visited[nextRegion->getRegionNumber()];
338+
}))
339+
return true;
340+
return false;
341+
}
342+
309343
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
310344
while (Region *region = op->getParentRegion()) {
311345
op = region->getParentOp();

0 commit comments

Comments
 (0)