Skip to content

Commit ccd8573

Browse files
StarryCSFZhiQiang Fan
authored andcommitted
[flang][fir] Add fir.if -> scf.if and add filecheck test file (llvm#142965)
This commmit is a supplement for llvm#140374. RFC:https://discourse.llvm.org/t/rfc-add-fir-affine-optimization-fir-pass-pipeline/86190/6 --------- Co-authored-by: ZhiQiang Fan <[email protected]>
1 parent fd8e9fe commit ccd8573

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

flang/lib/Optimizer/Transforms/FIRToSCF.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,48 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
8787
return success();
8888
}
8989
};
90+
91+
struct IfConversion : public OpRewritePattern<fir::IfOp> {
92+
using OpRewritePattern<fir::IfOp>::OpRewritePattern;
93+
LogicalResult matchAndRewrite(fir::IfOp ifOp,
94+
PatternRewriter &rewriter) const override {
95+
mlir::Location loc = ifOp.getLoc();
96+
mlir::detail::TypedValue<mlir::IntegerType> condition = ifOp.getCondition();
97+
ValueTypeRange<ResultRange> resultTypes = ifOp.getResultTypes();
98+
mlir::scf::IfOp scfIfOp = rewriter.create<scf::IfOp>(
99+
loc, resultTypes, condition, !ifOp.getElseRegion().empty());
100+
// then region
101+
scfIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
102+
Block &scfThenBlock = scfIfOp.getThenRegion().front();
103+
Operation *scfThenTerminator = scfThenBlock.getTerminator();
104+
// fir.result->scf.yield
105+
rewriter.setInsertionPointToEnd(&scfThenBlock);
106+
rewriter.replaceOpWithNewOp<scf::YieldOp>(scfThenTerminator,
107+
scfThenTerminator->getOperands());
108+
109+
// else region
110+
if (!ifOp.getElseRegion().empty()) {
111+
scfIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
112+
mlir::Block &elseBlock = scfIfOp.getElseRegion().front();
113+
mlir::Operation *elseTerminator = elseBlock.getTerminator();
114+
115+
rewriter.setInsertionPointToEnd(&elseBlock);
116+
rewriter.replaceOpWithNewOp<scf::YieldOp>(elseTerminator,
117+
elseTerminator->getOperands());
118+
}
119+
120+
scfIfOp->setAttrs(ifOp->getAttrs());
121+
rewriter.replaceOp(ifOp, scfIfOp);
122+
return success();
123+
}
124+
};
90125
} // namespace
91126

92127
void FIRToSCFPass::runOnOperation() {
93128
RewritePatternSet patterns(&getContext());
94-
patterns.add<DoLoopConversion>(patterns.getContext());
129+
patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
95130
ConversionTarget target(getContext());
96-
target.addIllegalOp<fir::DoLoopOp>();
131+
target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
97132
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
98133
if (failed(
99134
applyPartialConversion(getOperation(), target, std::move(patterns))))

flang/test/Fir/FirToSCF/if.fir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: fir-opt %s --fir-to-scf | FileCheck %s
2+
3+
// CHECK: func.func @test_only(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) {
4+
// CHECK: scf.if %[[ARG0:.*]] {
5+
// CHECK: %[[VAL_1:.*]] = arith.addi %[[ARG1:.*]], %[[ARG1:.*]] : i32
6+
// CHECK: }
7+
// CHECK: return
8+
// CHECK: }
9+
func.func @test_only(%arg0 : i1, %arg1 : i32) {
10+
fir.if %arg0 {
11+
%0 = arith.addi %arg1, %arg1 : i32
12+
}
13+
return
14+
}
15+
16+
// CHECK: func.func @test_else() {
17+
// CHECK: %[[VAL_1:.*]] = arith.constant false
18+
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
19+
// CHECK: scf.if %[[VAL_1:.*]] {
20+
// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32
21+
// CHECK: } else {
22+
// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32
23+
// CHECK: }
24+
// CHECK: return
25+
// CHECK: }
26+
func.func @test_else() {
27+
%false = arith.constant false
28+
%1 = arith.constant 2 : i32
29+
fir.if %false {
30+
%2 = arith.constant 3 : i32
31+
} else {
32+
%3 = arith.constant 3 : i32
33+
}
34+
return
35+
}
36+
37+
// CHECK-LABEL: func.func @test_two_result() {
38+
// CHECK: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
39+
// CHECK: %[[VAL_2:.*]] = arith.constant false
40+
// CHECK: %[[RES:[0-9]+]]:2 = scf.if %[[VAL_2:.*]] -> (f32, f32) {
41+
// CHECK: scf.yield %[[VAL_1:.*]], %[[VAL_1:.*]] : f32, f32
42+
// CHECK: } else {
43+
// CHECK: scf.yield %[[VAL_1:.*]], %[[VAL_1:.*]] : f32, f32
44+
// CHECK: }
45+
// CHECK: return
46+
// CHECK: }
47+
func.func @test_two_result() {
48+
%1 = arith.constant 2.0 : f32
49+
%cmp = arith.constant false
50+
%x, %y = fir.if %cmp -> (f32, f32) {
51+
fir.result %1, %1 : f32, f32
52+
} else {
53+
fir.result %1, %1 : f32, f32
54+
}
55+
return
56+
}

0 commit comments

Comments
 (0)