@@ -37,7 +37,100 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
37
37
void runOnOperation () override ;
38
38
};
39
39
40
- // Lower scf::if to emitc::if, implementing return values as emitc::variable's
40
+ // Lower scf::for to emitc::for, implementing result values using
41
+ // emitc::variable's updated within the loop body.
42
+ struct ForLowering : public OpRewritePattern <ForOp> {
43
+ using OpRewritePattern<ForOp>::OpRewritePattern;
44
+
45
+ LogicalResult matchAndRewrite (ForOp forOp,
46
+ PatternRewriter &rewriter) const override ;
47
+ };
48
+
49
+ // Create an uninitialized emitc::variable op for each result of the given op.
50
+ template <typename T>
51
+ static SmallVector<Value> createVariablesForResults (T op,
52
+ PatternRewriter &rewriter) {
53
+ SmallVector<Value> resultVariables;
54
+
55
+ if (!op.getNumResults ())
56
+ return resultVariables;
57
+
58
+ Location loc = op->getLoc ();
59
+ MLIRContext *context = op.getContext ();
60
+
61
+ OpBuilder::InsertionGuard guard (rewriter);
62
+ rewriter.setInsertionPoint (op);
63
+
64
+ for (OpResult result : op.getResults ()) {
65
+ Type resultType = result.getType ();
66
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
67
+ emitc::VariableOp var =
68
+ rewriter.create <emitc::VariableOp>(loc, resultType, noInit);
69
+ resultVariables.push_back (var);
70
+ }
71
+
72
+ return resultVariables;
73
+ }
74
+
75
+ // Create a series of assign ops assigning given values to given variables at
76
+ // the current insertion point of given rewriter.
77
+ static void assignValues (ValueRange values, SmallVector<Value> &variables,
78
+ PatternRewriter &rewriter, Location loc) {
79
+ for (auto [value, var] : llvm::zip (values, variables))
80
+ rewriter.create <emitc::AssignOp>(loc, var, value);
81
+ }
82
+
83
+ static void lowerYield (SmallVector<Value> &resultVariables,
84
+ PatternRewriter &rewriter, scf::YieldOp yield) {
85
+ Location loc = yield.getLoc ();
86
+ ValueRange operands = yield.getOperands ();
87
+
88
+ OpBuilder::InsertionGuard guard (rewriter);
89
+ rewriter.setInsertionPoint (yield);
90
+
91
+ assignValues (operands, resultVariables, rewriter, loc);
92
+
93
+ rewriter.create <emitc::YieldOp>(loc);
94
+ rewriter.eraseOp (yield);
95
+ }
96
+
97
+ LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
98
+ PatternRewriter &rewriter) const {
99
+ Location loc = forOp.getLoc ();
100
+
101
+ // Create an emitc::variable op for each result. These variables will be
102
+ // assigned to by emitc::assign ops within the loop body.
103
+ SmallVector<Value> resultVariables =
104
+ createVariablesForResults (forOp, rewriter);
105
+ SmallVector<Value> iterArgsVariables =
106
+ createVariablesForResults (forOp, rewriter);
107
+
108
+ assignValues (forOp.getInits (), iterArgsVariables, rewriter, loc);
109
+
110
+ emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
111
+ loc, forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ());
112
+
113
+ Block *loweredBody = loweredFor.getBody ();
114
+
115
+ // Erase the auto-generated terminator for the lowered for op.
116
+ rewriter.eraseOp (loweredBody->getTerminator ());
117
+
118
+ SmallVector<Value> replacingValues;
119
+ replacingValues.push_back (loweredFor.getInductionVar ());
120
+ replacingValues.append (iterArgsVariables.begin (), iterArgsVariables.end ());
121
+
122
+ rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
123
+ lowerYield (iterArgsVariables, rewriter,
124
+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
125
+
126
+ // Copy iterArgs into results after the for loop.
127
+ assignValues (iterArgsVariables, resultVariables, rewriter, loc);
128
+
129
+ rewriter.replaceOp (forOp, resultVariables);
130
+ return success ();
131
+ }
132
+
133
+ // Lower scf::if to emitc::if, implementing result values as emitc::variable's
41
134
// updated within the then and else regions.
42
135
struct IfLowering : public OpRewritePattern <IfOp> {
43
136
using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -52,20 +145,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
52
145
PatternRewriter &rewriter) const {
53
146
Location loc = ifOp.getLoc ();
54
147
55
- SmallVector<Value> resultVariables;
56
-
57
148
// Create an emitc::variable op for each result. These variables will be
58
149
// assigned to by emitc::assign ops within the then & else regions.
59
- if (ifOp.getNumResults ()) {
60
- MLIRContext *context = ifOp.getContext ();
61
- rewriter.setInsertionPoint (ifOp);
62
- for (OpResult result : ifOp.getResults ()) {
63
- Type resultType = result.getType ();
64
- auto noInit = emitc::OpaqueAttr::get (context, " " );
65
- auto var = rewriter.create <emitc::VariableOp>(loc, resultType, noInit);
66
- resultVariables.push_back (var);
67
- }
68
- }
150
+ SmallVector<Value> resultVariables =
151
+ createVariablesForResults (ifOp, rewriter);
69
152
70
153
// Utility function to lower the contents of an scf::if region to an emitc::if
71
154
// region. The contents of the scf::if regions is moved into the respective
@@ -76,16 +159,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
76
159
Region &loweredRegion) {
77
160
rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
78
161
Operation *terminator = loweredRegion.back ().getTerminator ();
79
- Location terminatorLoc = terminator->getLoc ();
80
- ValueRange terminatorOperands = terminator->getOperands ();
81
- rewriter.setInsertionPointToEnd (&loweredRegion.back ());
82
- for (auto value2Var : llvm::zip (terminatorOperands, resultVariables)) {
83
- Value resultValue = std::get<0 >(value2Var);
84
- Value resultVar = std::get<1 >(value2Var);
85
- rewriter.create <emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
86
- }
87
- rewriter.create <emitc::YieldOp>(terminatorLoc);
88
- rewriter.eraseOp (terminator);
162
+ lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
89
163
};
90
164
91
165
Region &thenRegion = ifOp.getThenRegion ();
@@ -109,6 +183,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
109
183
}
110
184
111
185
void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
186
+ patterns.add <ForLowering>(patterns.getContext ());
112
187
patterns.add <IfLowering>(patterns.getContext ());
113
188
}
114
189
@@ -118,7 +193,7 @@ void SCFToEmitCPass::runOnOperation() {
118
193
119
194
// Configure conversion to lower out SCF operations.
120
195
ConversionTarget target (getContext ());
121
- target.addIllegalOp <scf::IfOp>();
196
+ target.addIllegalOp <scf::ForOp, scf:: IfOp>();
122
197
target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
123
198
if (failed (
124
199
applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments