Skip to content

Commit 0db1ae3

Browse files
committed
[mlir][CFGToSCF] Visit subregions in CFGToSCF pass
This is useful when user already have partially-scf'ed IR or other ops with nested regions (e.g. linalg.generic). Also, improve error message and pass docs. Differential Revision: https://reviews.llvm.org/D158349
1 parent 4d434f7 commit 0db1ae3

File tree

3 files changed

+60
-10
lines changed

3 files changed

+60
-10
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,12 @@ def LiftControlFlowToSCFPass : Pass<"lift-cf-to-scf"> {
304304
ControlFlow operations will be replaced successfully.
305305
Otherwise a single ControlFlow switch branching to one block per return-like
306306
operation kind remains.
307+
308+
This pass may need to create unreachable terminators in case of infinite
309+
loops, which is only supported for 'func.func' for now. If you potentially
310+
have infinite loops inside CFG regions not belonging to 'func.func',
311+
consider using `transformCFGToSCF` function directly with corresponding
312+
`CFGToSCFInterface::createUnreachableTerminator` implementation.
307313
}];
308314

309315
let dependentDialects = ["scf::SCFDialect",

mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
140140
// TODO: This should create a `ub.unreachable` op. Once such an operation
141141
// exists to make the pass independent of the func dialect. For now just
142142
// return poison values.
143-
auto funcOp = dyn_cast<func::FuncOp>(region.getParentOp());
143+
Operation *parentOp = region.getParentOp();
144+
auto funcOp = dyn_cast<func::FuncOp>(parentOp);
144145
if (!funcOp)
145-
return emitError(loc, "Expected '")
146-
<< func::FuncOp::getOperationName() << "' as top level operation";
146+
return emitError(loc, "Cannot create unreachable terminator for '")
147+
<< parentOp->getName() << "'";
147148

148149
return builder
149150
.create<func::ReturnOp>(
@@ -165,18 +166,29 @@ struct LiftControlFlowToSCF
165166
ControlFlowToSCFTransformation transformation;
166167

167168
bool changed = false;
168-
WalkResult result = getOperation()->walk([&](func::FuncOp funcOp) {
169+
Operation *op = getOperation();
170+
WalkResult result = op->walk([&](func::FuncOp funcOp) {
169171
if (funcOp.getBody().empty())
170172
return WalkResult::advance();
171173

172-
FailureOr<bool> changedFunc = transformCFGToSCF(
173-
funcOp.getBody(), transformation,
174-
funcOp != getOperation() ? getChildAnalysis<DominanceInfo>(funcOp)
175-
: getAnalysis<DominanceInfo>());
176-
if (failed(changedFunc))
174+
auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
175+
: getAnalysis<DominanceInfo>();
176+
177+
auto visitor = [&](Operation *innerOp) -> WalkResult {
178+
for (Region &reg : innerOp->getRegions()) {
179+
FailureOr<bool> changedFunc =
180+
transformCFGToSCF(reg, transformation, domInfo);
181+
if (failed(changedFunc))
182+
return WalkResult::interrupt();
183+
184+
changed |= *changedFunc;
185+
}
186+
return WalkResult::advance();
187+
};
188+
189+
if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
177190
return WalkResult::interrupt();
178191

179-
changed |= *changedFunc;
180192
return WalkResult::advance();
181193
});
182194
if (result.wasInterrupted())

mlir/test/Conversion/ControlFlowToSCF/test.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,3 +678,35 @@ func.func @multi_entry_loop(%cond: i1) {
678678
// CHECK: scf.yield
679679
// CHECK: call @foo(%[[WHILE]]#1)
680680
// CHECK-NEXT: return
681+
682+
// -----
683+
684+
func.func @nested_region() {
685+
scf.execute_region {
686+
%cond = "test.test1"() : () -> i1
687+
cf.cond_br %cond, ^bb1, ^bb2
688+
^bb1:
689+
"test.test2"() : () -> ()
690+
cf.br ^bb3
691+
^bb2:
692+
"test.test3"() : () -> ()
693+
cf.br ^bb3
694+
^bb3:
695+
"test.test4"() : () -> ()
696+
scf.yield
697+
}
698+
return
699+
}
700+
701+
// CHECK-LABEL: func @nested_region
702+
// CHECK: scf.execute_region {
703+
// CHECK: %[[COND:.*]] = "test.test1"()
704+
// CHECK-NEXT: scf.if %[[COND]]
705+
// CHECK-NEXT: "test.test2"()
706+
// CHECK-NEXT: else
707+
// CHECK-NEXT: "test.test3"()
708+
// CHECK-NEXT: }
709+
// CHECK-NEXT: "test.test4"()
710+
// CHECK-NEXT: scf.yield
711+
// CHECK-NEXT: }
712+
// CHECK-NEXT: return

0 commit comments

Comments
 (0)