Skip to content

Commit d6bc837

Browse files
author
ZhiQiang Fan
committed
[flang][fir] Add fir.if -> scf.if and add filecheck test file
1 parent 6da8f3b commit d6bc837

File tree

2 files changed

+153
-2
lines changed

2 files changed

+153
-2
lines changed

flang/lib/Optimizer/Transforms/FIRToSCF.cpp

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,67 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
8787
return success();
8888
}
8989
};
90+
91+
struct IfConversion : public OpRewritePattern<fir::IfOp> {
92+
using OpRewritePattern<fir::IfOp>::OpRewritePattern;
93+
LogicalResult matchAndRewrite(fir::IfOp ifOp,
94+
PatternRewriter &rewriter) const override {
95+
mlir::Location loc = ifOp.getLoc();
96+
mlir::detail::TypedValue<mlir::IntegerType> condition = ifOp.getCondition();
97+
ValueTypeRange<ResultRange> resultTypes = ifOp.getResultTypes();
98+
bool hasResult = !resultTypes.empty();
99+
auto scfIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, condition,
100+
!ifOp.getElseRegion().empty());
101+
// then region
102+
assert(!ifOp.getThenRegion().empty() && "must have then region");
103+
auto &firThenBlock = ifOp.getThenRegion().front();
104+
auto &scfThenBlock = scfIfOp.getThenRegion().front();
105+
auto &firThenOps = firThenBlock.getOperations();
106+
mlir::Operation *firThenTerminator = firThenBlock.getTerminator();
107+
108+
rewriter.setInsertionPointToStart(&scfThenBlock);
109+
// not splice terminator
110+
scfThenBlock.getOperations().splice(scfThenBlock.begin(), firThenOps,
111+
firThenOps.begin(),
112+
std::prev(firThenOps.end()));
113+
// create terminator scf.yield
114+
if (hasResult) {
115+
rewriter.setInsertionPointToEnd(&scfThenBlock);
116+
mlir::OperandRange thenResults = firThenTerminator->getOperands();
117+
rewriter.create<scf::YieldOp>(firThenTerminator->getLoc(), thenResults);
118+
}
119+
120+
// else region
121+
if (!ifOp.getElseRegion().empty()) {
122+
auto &firElseBlock = ifOp.getElseRegion().front();
123+
auto &scfElseBlock = scfIfOp.getElseRegion().front();
124+
auto &firElseOps = firElseBlock.getOperations();
125+
mlir::Operation *firElseTerminator = firElseBlock.getTerminator();
126+
127+
rewriter.setInsertionPointToStart(&scfElseBlock);
128+
scfElseBlock.getOperations().splice(scfElseBlock.begin(), firElseOps,
129+
firElseOps.begin(),
130+
std::prev(firElseOps.end()));
131+
132+
if (hasResult) {
133+
rewriter.setInsertionPointToEnd(&scfElseBlock);
134+
mlir::OperandRange elseResults = firElseTerminator->getOperands();
135+
rewriter.create<scf::YieldOp>(firElseTerminator->getLoc(), elseResults);
136+
}
137+
}
138+
139+
scfIfOp->setAttrs(ifOp->getAttrs());
140+
rewriter.replaceOp(ifOp, scfIfOp);
141+
return success();
142+
}
143+
};
90144
} // namespace
91145

92146
void FIRToSCFPass::runOnOperation() {
93147
RewritePatternSet patterns(&getContext());
94-
patterns.add<DoLoopConversion>(patterns.getContext());
148+
patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
95149
ConversionTarget target(getContext());
96-
target.addIllegalOp<fir::DoLoopOp>();
150+
target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
97151
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
98152
if (failed(
99153
applyPartialConversion(getOperation(), target, std::move(patterns))))

flang/test/Fir/FirToSCF/if.fir

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
2+
// RUN: fir-opt %s --fir-to-scf | FileCheck %s
3+
4+
// CHECK: func.func @_QFPtest_only(
5+
// CHECK: %[[ARG0:.*]]: !fir.ref<tuple<!fir.ref<f32>>>) {
6+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
7+
// CHECK: %[[VAL_1:.*]] = arith.constant false
8+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
9+
// CHECK: scf.if %[[VAL_1:.*]] {
10+
// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_2:.*]], %[[VAL_0:.*]] : i32
11+
// CHECK: }
12+
// CHECK: return
13+
// CHECK: }
14+
func.func @_QFPtest_only(%arg0: !fir.ref<tuple<!fir.ref<f32>>>) {
15+
%c1_i32 = arith.constant 1 : i32
16+
%false = arith.constant false
17+
%c0_i32 = arith.constant 0 : i32
18+
fir.if %false {
19+
%0 = arith.addi %c0_i32, %c1_i32 : i32
20+
}
21+
return
22+
}
23+
24+
// CHECK: func.func @_QFPtest_else(
25+
// CHECK: %[[ARG0:.*]]: !fir.ref<tuple<!fir.ref<f32>>>) {
26+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
27+
// CHECK: %[[VAL_1:.*]] = arith.constant false
28+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
29+
// CHECK: %[[VAL_3:.*]] = fir.dummy_scope : !fir.dscope
30+
// CHECK: %[[VAL_4:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_2:.*]] : (!fir.ref<tuple<!fir.ref<f32>>>, i32) -> !fir.llvm_ptr<!fir.ref<f32>>
31+
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4:.*]] : !fir.llvm_ptr<!fir.ref<f32>>
32+
// CHECK: %[[VAL_6:.*]] = fir.declare %[[VAL_5:.*]] {fortran_attrs = #fir.var_attrs<host_assoc>, uniq_name = "_QFEx"} : (!fir.ref<f32>) -> !fir.ref<f32>
33+
// CHECK: %[[VAL_7:.*]] = fir.address_of(@_QFFtest_elseEsum) : !fir.ref<i32>
34+
// CHECK: %[[VAL_10:.*]] = fir.declare %[[VAL_7:.*]] {uniq_name = "_QFFtest_elseEsum"} : (!fir.ref<i32>) -> !fir.ref<i32>
35+
// CHECK: scf.if %[[VAL_1:.*]] {
36+
// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_10:.*]] : !fir.ref<i32>
37+
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8:.*]], %[[VAL_0:.*]] : i32
38+
// CHECK: fir.store %[[VAL_9:.*]] to %[[VAL_10:.*]] : !fir.ref<i32>
39+
// CHECK: } else {
40+
// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_10:.*]] : !fir.ref<i32>
41+
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8:.*]], %c1_i32 : i32
42+
// CHECK: fir.store %[[VAL_9:.*]] to %[[VAL_10:.*]] : !fir.ref<i32>
43+
// CHECK: }
44+
// CHECK: return
45+
// CHECK: }
46+
func.func @_QFPtest_else(%arg0: !fir.ref<tuple<!fir.ref<f32>>> {}) attributes {} {
47+
%c1_i32 = arith.constant 1 : i32
48+
%false = arith.constant false
49+
%c0_i32 = arith.constant 0 : i32
50+
%0 = fir.dummy_scope : !fir.dscope
51+
%1 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.ref<f32>>>, i32) -> !fir.llvm_ptr<!fir.ref<f32>>
52+
%2 = fir.load %1 : !fir.llvm_ptr<!fir.ref<f32>>
53+
%3 = fir.declare %2 {fortran_attrs = #fir.var_attrs<host_assoc>, uniq_name = "_QFEx"} : (!fir.ref<f32>) -> !fir.ref<f32>
54+
%4 = fir.address_of(@_QFFtest_elseEsum) : !fir.ref<i32>
55+
%5 = fir.declare %4 {uniq_name = "_QFFtest_elseEsum"} : (!fir.ref<i32>) -> !fir.ref<i32>
56+
fir.if %false {
57+
%6 = fir.load %5 : !fir.ref<i32>
58+
%7 = arith.addi %6, %c1_i32 : i32
59+
fir.store %7 to %5 : !fir.ref<i32>
60+
} else {
61+
%6 = fir.load %5 : !fir.ref<i32>
62+
%7 = arith.addi %6, %c1_i32 : i32
63+
fir.store %7 to %5 : !fir.ref<i32>
64+
}
65+
return
66+
}
67+
68+
// CHECK-LABEL: func.func @test_two_result() {
69+
// CHECK: %[[VAL_0:.*]] = arith.constant 10 : i32
70+
// CHECK: %[[VAL_1:.*]] = arith.constant 5 : i32
71+
// CHECK: %[[VAL_2:.*]] = arith.cmpi sgt, %[[VAL_0:.*]], %[[VAL_1:.*]] : i32
72+
// CHECK: %[[VAL_3:.*]] = arith.constant 3.140000e+00 : f32
73+
// CHECK: %[[VAL_4:.*]] = arith.constant 2.710000e+00 : f32
74+
// CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
75+
// CHECK: %[[VAL_6:.*]] = arith.constant 2.000000e+00 : f32
76+
// CHECK: %[[RES:[0-9]+]]:2 = scf.if %[[VAL_2:.*]] -> (f32, f32) {
77+
// CHECK: scf.yield %[[VAL_3:.*]], %[[VAL_4:.*]] : f32, f32
78+
// CHECK: } else {
79+
// CHECK: scf.yield %[[VAL_5:.*]], %[[VAL_6:.*]] : f32, f32
80+
// CHECK: }
81+
// CHECK: return
82+
// CHECK: }
83+
func.func @test_two_result() {
84+
%c10_i32 = arith.constant 10 : i32
85+
%c5_i32 = arith.constant 5 : i32
86+
%cmp = arith.cmpi sgt, %c10_i32, %c5_i32 : i32
87+
%c3_14_f32 = arith.constant 3.14 : f32
88+
%c2_71_f32 = arith.constant 2.71 : f32
89+
%c1_0_f32 = arith.constant 1.0 : f32
90+
%c2_0_f32 = arith.constant 2.0 : f32
91+
%x, %y = fir.if %cmp -> (f32, f32) {
92+
fir.result %c3_14_f32, %c2_71_f32 : f32, f32
93+
} else {
94+
fir.result %c1_0_f32, %c2_0_f32 : f32, f32
95+
}
96+
return
97+
}

0 commit comments

Comments
 (0)