@@ -87,13 +87,48 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
87
87
return success ();
88
88
}
89
89
};
90
+
91
+ struct IfConversion : public OpRewritePattern <fir::IfOp> {
92
+ using OpRewritePattern<fir::IfOp>::OpRewritePattern;
93
+ LogicalResult matchAndRewrite (fir::IfOp ifOp,
94
+ PatternRewriter &rewriter) const override {
95
+ mlir::Location loc = ifOp.getLoc ();
96
+ mlir::detail::TypedValue<mlir::IntegerType> condition = ifOp.getCondition ();
97
+ ValueTypeRange<ResultRange> resultTypes = ifOp.getResultTypes ();
98
+ mlir::scf::IfOp scfIfOp = rewriter.create <scf::IfOp>(
99
+ loc, resultTypes, condition, !ifOp.getElseRegion ().empty ());
100
+ // then region
101
+ scfIfOp.getThenRegion ().takeBody (ifOp.getThenRegion ());
102
+ Block &scfThenBlock = scfIfOp.getThenRegion ().front ();
103
+ Operation *scfThenTerminator = scfThenBlock.getTerminator ();
104
+ // fir.result->scf.yield
105
+ rewriter.setInsertionPointToEnd (&scfThenBlock);
106
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(scfThenTerminator,
107
+ scfThenTerminator->getOperands ());
108
+
109
+ // else region
110
+ if (!ifOp.getElseRegion ().empty ()) {
111
+ scfIfOp.getElseRegion ().takeBody (ifOp.getElseRegion ());
112
+ mlir::Block &elseBlock = scfIfOp.getElseRegion ().front ();
113
+ mlir::Operation *elseTerminator = elseBlock.getTerminator ();
114
+
115
+ rewriter.setInsertionPointToEnd (&elseBlock);
116
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(elseTerminator,
117
+ elseTerminator->getOperands ());
118
+ }
119
+
120
+ scfIfOp->setAttrs (ifOp->getAttrs ());
121
+ rewriter.replaceOp (ifOp, scfIfOp);
122
+ return success ();
123
+ }
124
+ };
90
125
} // namespace
91
126
92
127
void FIRToSCFPass::runOnOperation () {
93
128
RewritePatternSet patterns (&getContext ());
94
- patterns.add <DoLoopConversion>(patterns.getContext ());
129
+ patterns.add <DoLoopConversion, IfConversion >(patterns.getContext ());
95
130
ConversionTarget target (getContext ());
96
- target.addIllegalOp <fir::DoLoopOp>();
131
+ target.addIllegalOp <fir::DoLoopOp, fir::IfOp >();
97
132
target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
98
133
if (failed (
99
134
applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments