Skip to content

Commit 2a01d7f

Browse files
author
Nicolas Vasilache
committed
[mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp
Differential Revision: https://reviews.llvm.org/D85449
1 parent aedaa07 commit 2a01d7f

File tree

6 files changed

+166
-5
lines changed

6 files changed

+166
-5
lines changed

mlir/include/mlir/Dialect/SCF/Utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,15 @@
1313
#ifndef MLIR_DIALECT_SCF_UTILS_H_
1414
#define MLIR_DIALECT_SCF_UTILS_H_
1515

16+
#include "mlir/Support/LLVM.h"
17+
1618
namespace mlir {
19+
class FuncOp;
1720
class OpBuilder;
1821
class ValueRange;
1922

2023
namespace scf {
24+
class IfOp;
2125
class ForOp;
2226
class ParallelOp;
2327
} // end namespace scf
@@ -46,5 +50,12 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
4650
ValueRange newYieldedValues,
4751
bool replaceLoopResults = true);
4852

53+
/// Outline the then and/or else regions of `ifOp` as follows:
54+
/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
55+
/// region is inlined into a new FuncOp that is captured by the pointer.
56+
/// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
57+
/// region is inlined into a new FuncOp that is captured by the pointer.
58+
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
59+
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
4960
} // end namespace mlir
5061
#endif // MLIR_DIALECT_SCF_UTILS_H_

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ add_mlir_dialect_library(MLIRSCFTransforms
1717
MLIRSCF
1818
MLIRStandardOps
1919
MLIRSupport
20-
)
20+
MLIRTransformUtils
21+
)

mlir/lib/Dialect/SCF/Transforms/Utils.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
#include "mlir/Dialect/SCF/Utils.h"
1414

1515
#include "mlir/Dialect/SCF/SCF.h"
16+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1617
#include "mlir/IR/BlockAndValueMapping.h"
18+
#include "mlir/IR/Function.h"
19+
#include "mlir/Transforms/RegionUtils.h"
20+
21+
#include "llvm/ADT/SetVector.h"
1722

1823
using namespace mlir;
1924

@@ -71,3 +76,50 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
7176

7277
return newLoop;
7378
}
79+
80+
void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
81+
StringRef thenFnName, FuncOp *elseFn,
82+
StringRef elseFnName) {
83+
Location loc = ifOp.getLoc();
84+
MLIRContext *ctx = ifOp.getContext();
85+
auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
86+
assert(!funcName.empty() && "Expected function name for outlining");
87+
assert(ifOrElseRegion.getBlocks().size() <= 1 &&
88+
"Expected at most one block");
89+
90+
// Outline before current function.
91+
OpBuilder::InsertionGuard g(b);
92+
b.setInsertionPoint(ifOp.getParentOfType<FuncOp>());
93+
94+
llvm::SetVector<Value> captures;
95+
getUsedValuesDefinedAbove(ifOrElseRegion, captures);
96+
97+
ValueRange values(captures.getArrayRef());
98+
FunctionType type =
99+
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
100+
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
101+
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
102+
BlockAndValueMapping bvm;
103+
for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
104+
bvm.map(std::get<0>(it), std::get<1>(it));
105+
for (Operation &op : ifOrElseRegion.front().without_terminator())
106+
b.clone(op, bvm);
107+
108+
Operation *term = ifOrElseRegion.front().getTerminator();
109+
SmallVector<Value, 4> terminatorOperands;
110+
for (auto op : term->getOperands())
111+
terminatorOperands.push_back(bvm.lookup(op));
112+
b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
113+
114+
ifOrElseRegion.front().clear();
115+
b.setInsertionPointToEnd(&ifOrElseRegion.front());
116+
Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
117+
b.create<scf::YieldOp>(loc, call->getResults());
118+
return outlinedFunc;
119+
};
120+
121+
if (thenFn && !ifOp.thenRegion().empty())
122+
*thenFn = outline(ifOp.thenRegion(), thenFnName);
123+
if (elseFn && !ifOp.elseRegion().empty())
124+
*elseFn = outline(ifOp.elseRegion(), elseFnName);
125+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
2+
3+
// -----
4+
5+
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) -> i8 {
6+
// CHECK-NEXT: %{{.*}} = "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
7+
// CHECK-NEXT: return %{{.*}} : i8
8+
// CHECK-NEXT: }
9+
// CHECK: func @outlined_else0(%{{.*}}: i8) -> i8 {
10+
// CHECK-NEXT: return %{{.*}}0 : i8
11+
// CHECK-NEXT: }
12+
// CHECK: func @outline_if_else(
13+
// CHECK-NEXT: %{{.*}} = scf.if %{{.*}} -> (i8) {
14+
// CHECK-NEXT: %{{.*}} = call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
15+
// CHECK-NEXT: scf.yield %{{.*}} : i8
16+
// CHECK-NEXT: } else {
17+
// CHECK-NEXT: %{{.*}} = call @outlined_else0(%{{.*}}) : (i8) -> i8
18+
// CHECK-NEXT: scf.yield %{{.*}} : i8
19+
// CHECK-NEXT: }
20+
// CHECK-NEXT: return
21+
// CHECK-NEXT: }
22+
func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
23+
%r = scf.if %cond -> (i8) {
24+
%r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
25+
scf.yield %r : i8
26+
} else {
27+
scf.yield %c : i8
28+
}
29+
return
30+
}
31+
32+
// -----
33+
34+
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
35+
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
36+
// CHECK-NEXT: return
37+
// CHECK-NEXT: }
38+
// CHECK: func @outline_if(
39+
// CHECK-NEXT: scf.if %{{.*}} {
40+
// CHECK-NEXT: call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
41+
// CHECK-NEXT: }
42+
// CHECK-NEXT: return
43+
// CHECK-NEXT: }
44+
func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
45+
scf.if %cond {
46+
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
47+
scf.yield
48+
}
49+
return
50+
}
51+
52+
// -----
53+
54+
// CHECK: func @outlined_then0() {
55+
// CHECK-NEXT: return
56+
// CHECK-NEXT: }
57+
// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
58+
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
59+
// CHECK-NEXT: return
60+
// CHECK-NEXT: }
61+
// CHECK: func @outline_empty_if_else(
62+
// CHECK-NEXT: scf.if %{{.*}} {
63+
// CHECK-NEXT: call @outlined_then0() : () -> ()
64+
// CHECK-NEXT: } else {
65+
// CHECK-NEXT: call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
66+
// CHECK-NEXT: }
67+
// CHECK-NEXT: return
68+
// CHECK-NEXT: }
69+
func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
70+
scf.if %cond {
71+
} else {
72+
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
73+
}
74+
return
75+
}

mlir/test/Transforms/loop-utils.mlir renamed to mlir/test/Transforms/scf-loop-utils.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
1+
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
22

33
// CHECK-LABEL: @hoist
44
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,

mlir/test/lib/Transforms/TestSCFUtils.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
using namespace mlir;
2222

2323
namespace {
24-
class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
24+
class TestSCFForUtilsPass
25+
: public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
2526
public:
26-
explicit TestSCFUtilsPass() {}
27+
explicit TestSCFForUtilsPass() {}
2728

2829
void runOnFunction() override {
2930
FuncOp func = getFunction();
@@ -49,10 +50,31 @@ class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
4950
loop.erase();
5051
}
5152
};
53+
54+
class TestSCFIfUtilsPass
55+
: public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
56+
public:
57+
explicit TestSCFIfUtilsPass() {}
58+
59+
void runOnFunction() override {
60+
int count = 0;
61+
FuncOp func = getFunction();
62+
func.walk([&](scf::IfOp ifOp) {
63+
auto strCount = std::to_string(count++);
64+
FuncOp thenFn, elseFn;
65+
OpBuilder b(ifOp);
66+
outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
67+
&elseFn, std::string("outlined_else") + strCount);
68+
});
69+
}
70+
};
5271
} // end namespace
5372

5473
namespace mlir {
5574
void registerTestSCFUtilsPass() {
56-
PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
75+
PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
76+
"test scf.for utils");
77+
PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
78+
"test scf.if utils");
5779
}
5880
} // namespace mlir

0 commit comments

Comments
 (0)