14
14
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/EmitC/IR/EmitC.h"
17
- #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18
17
#include " mlir/Dialect/SCF/IR/SCF.h"
19
18
#include " mlir/IR/Builders.h"
20
19
#include " mlir/IR/BuiltinOps.h"
@@ -40,22 +39,21 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
40
39
41
40
// Lower scf::for to emitc::for, implementing result values using
42
41
// emitc::variable's updated within the loop body.
43
- struct ForLowering : public OpConversionPattern <ForOp> {
44
- using OpConversionPattern <ForOp>::OpConversionPattern ;
42
+ struct ForLowering : public OpRewritePattern <ForOp> {
43
+ using OpRewritePattern <ForOp>::OpRewritePattern ;
45
44
46
- LogicalResult
47
- matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
48
- ConversionPatternRewriter &rewriter) const override ;
45
+ LogicalResult matchAndRewrite (ForOp forOp,
46
+ PatternRewriter &rewriter) const override ;
49
47
};
50
48
51
49
// Create an uninitialized emitc::variable op for each result of the given op.
52
50
template <typename T>
53
- static LogicalResult
54
- createVariablesForResults (T op, const TypeConverter *typeConverter,
55
- ConversionPatternRewriter &rewriter,
56
- SmallVector<Value> &resultVariables) {
51
+ static SmallVector<Value> createVariablesForResults (T op,
52
+ PatternRewriter &rewriter) {
53
+ SmallVector<Value> resultVariables;
54
+
57
55
if (!op.getNumResults ())
58
- return success () ;
56
+ return resultVariables ;
59
57
60
58
Location loc = op->getLoc ();
61
59
MLIRContext *context = op.getContext ();
@@ -64,23 +62,21 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
64
62
rewriter.setInsertionPoint (op);
65
63
66
64
for (OpResult result : op.getResults ()) {
67
- Type resultType = typeConverter->convertType (result.getType ());
68
- if (!resultType)
69
- return rewriter.notifyMatchFailure (op, " result type conversion failed" );
65
+ Type resultType = result.getType ();
70
66
Type varType = emitc::LValueType::get (resultType);
71
67
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
72
68
emitc::VariableOp var =
73
69
rewriter.create <emitc::VariableOp>(loc, varType, noInit);
74
70
resultVariables.push_back (var);
75
71
}
76
72
77
- return success () ;
73
+ return resultVariables ;
78
74
}
79
75
80
76
// Create a series of assign ops assigning given values to given variables at
81
77
// the current insertion point of given rewriter.
82
- static void assignValues (ValueRange values, ValueRange variables,
83
- ConversionPatternRewriter &rewriter, Location loc) {
78
+ static void assignValues (ValueRange values, SmallVector<Value> & variables,
79
+ PatternRewriter &rewriter, Location loc) {
84
80
for (auto [value, var] : llvm::zip (values, variables))
85
81
rewriter.create <emitc::AssignOp>(loc, var, value);
86
82
}
@@ -93,58 +89,46 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
93
89
});
94
90
}
95
91
96
- static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
97
- ConversionPatternRewriter &rewriter,
98
- scf::YieldOp yield) {
92
+ static void lowerYield (SmallVector<Value> &resultVariables,
93
+ PatternRewriter &rewriter, scf::YieldOp yield) {
99
94
Location loc = yield.getLoc ();
95
+ ValueRange operands = yield.getOperands ();
100
96
101
97
OpBuilder::InsertionGuard guard (rewriter);
102
98
rewriter.setInsertionPoint (yield);
103
99
104
- SmallVector<Value> yieldOperands;
105
- if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
106
- return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
107
- }
108
-
109
- assignValues (yieldOperands, resultVariables, rewriter, loc);
100
+ assignValues (operands, resultVariables, rewriter, loc);
110
101
111
102
rewriter.create <emitc::YieldOp>(loc);
112
103
rewriter.eraseOp (yield);
113
-
114
- return success ();
115
104
}
116
105
117
106
// Lower the contents of an scf::if/scf::index_switch regions to an
118
107
// emitc::if/emitc::switch region. The contents of the lowering region is
119
108
// moved into the respective lowered region, but the scf::yield is replaced not
120
109
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
121
110
// set the yielded values into the result variables.
122
- static LogicalResult lowerRegion (Operation *op, ValueRange resultVariables,
123
- ConversionPatternRewriter &rewriter ,
124
- Region ®ion, Region &loweredRegion) {
111
+ static void lowerRegion (SmallVector<Value> & resultVariables,
112
+ PatternRewriter &rewriter, Region ®ion ,
113
+ Region &loweredRegion) {
125
114
rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
126
115
Operation *terminator = loweredRegion.back ().getTerminator ();
127
- return lowerYield (op, resultVariables, rewriter,
128
- cast<scf::YieldOp>(terminator));
116
+ lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
129
117
}
130
118
131
- LogicalResult
132
- ForLowering::matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
133
- ConversionPatternRewriter &rewriter) const {
119
+ LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
120
+ PatternRewriter &rewriter) const {
134
121
Location loc = forOp.getLoc ();
135
122
136
123
// Create an emitc::variable op for each result. These variables will be
137
124
// assigned to by emitc::assign ops within the loop body.
138
- SmallVector<Value> resultVariables;
139
- if (failed (createVariablesForResults (forOp, getTypeConverter (), rewriter,
140
- resultVariables)))
141
- return rewriter.notifyMatchFailure (forOp,
142
- " create variables for results failed" );
125
+ SmallVector<Value> resultVariables =
126
+ createVariablesForResults (forOp, rewriter);
143
127
144
- assignValues (adaptor. getInitArgs (), resultVariables, rewriter, loc);
128
+ assignValues (forOp. getInits (), resultVariables, rewriter, loc);
145
129
146
130
emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
147
- loc, adaptor .getLowerBound (), adaptor .getUpperBound (), adaptor .getStep ());
131
+ loc, forOp .getLowerBound (), forOp .getUpperBound (), forOp .getStep ());
148
132
149
133
Block *loweredBody = loweredFor.getBody ();
150
134
@@ -159,27 +143,13 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
159
143
160
144
rewriter.restoreInsertionPoint (ip);
161
145
162
- // Convert the original region types into the new types by adding unrealized
163
- // casts in the beginning of the loop. This performs the conversion in place.
164
- if (failed (rewriter.convertRegionTypes (&forOp.getRegion (),
165
- *getTypeConverter (), nullptr ))) {
166
- return rewriter.notifyMatchFailure (forOp, " region types conversion failed" );
167
- }
168
-
169
- // Register the replacements for the block arguments and inline the body of
170
- // the scf.for loop into the body of the emitc::for loop.
171
- Block *scfBody = &(forOp.getRegion ().front ());
172
146
SmallVector<Value> replacingValues;
173
147
replacingValues.push_back (loweredFor.getInductionVar ());
174
148
replacingValues.append (iterArgsValues.begin (), iterArgsValues.end ());
175
- rewriter.mergeBlocks (scfBody, loweredBody, replacingValues);
176
149
177
- auto result = lowerYield (forOp, resultVariables, rewriter,
178
- cast<scf::YieldOp>(loweredBody->getTerminator ()));
179
-
180
- if (failed (result)) {
181
- return result;
182
- }
150
+ rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
151
+ lowerYield (resultVariables, rewriter,
152
+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
183
153
184
154
// Load variables into SSA values after the for loop.
185
155
SmallVector<Value> resultValues = loadValues (resultVariables, rewriter, loc);
@@ -190,66 +160,38 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
190
160
191
161
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
192
162
// updated within the then and else regions.
193
- struct IfLowering : public OpConversionPattern <IfOp> {
194
- using OpConversionPattern <IfOp>::OpConversionPattern ;
163
+ struct IfLowering : public OpRewritePattern <IfOp> {
164
+ using OpRewritePattern <IfOp>::OpRewritePattern ;
195
165
196
- LogicalResult
197
- matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
198
- ConversionPatternRewriter &rewriter) const override ;
166
+ LogicalResult matchAndRewrite (IfOp ifOp,
167
+ PatternRewriter &rewriter) const override ;
199
168
};
200
169
201
170
} // namespace
202
171
203
- LogicalResult
204
- IfLowering::matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
205
- ConversionPatternRewriter &rewriter) const {
172
+ LogicalResult IfLowering::matchAndRewrite (IfOp ifOp,
173
+ PatternRewriter &rewriter) const {
206
174
Location loc = ifOp.getLoc ();
207
175
208
176
// Create an emitc::variable op for each result. These variables will be
209
177
// assigned to by emitc::assign ops within the then & else regions.
210
- SmallVector<Value> resultVariables;
211
- if (failed (createVariablesForResults (ifOp, getTypeConverter (), rewriter,
212
- resultVariables)))
213
- return rewriter.notifyMatchFailure (ifOp,
214
- " create variables for results failed" );
215
-
216
- // Utility function to lower the contents of an scf::if region to an emitc::if
217
- // region. The contents of the scf::if regions is moved into the respective
218
- // emitc::if regions, but the scf::yield is replaced not only with an
219
- // emitc::yield, but also with a sequence of emitc::assign ops that set the
220
- // yielded values into the result variables.
221
- auto lowerRegion = [&resultVariables, &rewriter,
222
- &ifOp](Region ®ion, Region &loweredRegion) {
223
- rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
224
- Operation *terminator = loweredRegion.back ().getTerminator ();
225
- auto result = lowerYield (ifOp, resultVariables, rewriter,
226
- cast<scf::YieldOp>(terminator));
227
- if (failed (result)) {
228
- return result;
229
- }
230
- return success ();
231
- };
232
-
233
- Region &thenRegion = adaptor.getThenRegion ();
234
- Region &elseRegion = adaptor.getElseRegion ();
178
+ SmallVector<Value> resultVariables =
179
+ createVariablesForResults (ifOp, rewriter);
180
+
181
+ Region &thenRegion = ifOp.getThenRegion ();
182
+ Region &elseRegion = ifOp.getElseRegion ();
235
183
236
184
bool hasElseBlock = !elseRegion.empty ();
237
185
238
186
auto loweredIf =
239
- rewriter.create <emitc::IfOp>(loc, adaptor .getCondition (), false , false );
187
+ rewriter.create <emitc::IfOp>(loc, ifOp .getCondition (), false , false );
240
188
241
189
Region &loweredThenRegion = loweredIf.getThenRegion ();
242
- auto result = lowerRegion (thenRegion, loweredThenRegion);
243
- if (failed (result)) {
244
- return result;
245
- }
190
+ lowerRegion (resultVariables, rewriter, thenRegion, loweredThenRegion);
246
191
247
192
if (hasElseBlock) {
248
193
Region &loweredElseRegion = loweredIf.getElseRegion ();
249
- auto result = lowerRegion (elseRegion, loweredElseRegion);
250
- if (failed (result)) {
251
- return result;
252
- }
194
+ lowerRegion (resultVariables, rewriter, elseRegion, loweredElseRegion);
253
195
}
254
196
255
197
rewriter.setInsertionPointAfter (ifOp);
@@ -261,46 +203,37 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
261
203
262
204
// Lower scf::index_switch to emitc::switch, implementing result values as
263
205
// emitc::variable's updated within the case and default regions.
264
- struct IndexSwitchOpLowering : public OpConversionPattern <IndexSwitchOp> {
265
- using OpConversionPattern::OpConversionPattern ;
206
+ struct IndexSwitchOpLowering : public OpRewritePattern <IndexSwitchOp> {
207
+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern ;
266
208
267
- LogicalResult
268
- matchAndRewrite (IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269
- ConversionPatternRewriter &rewriter) const override ;
209
+ LogicalResult matchAndRewrite (IndexSwitchOp indexSwitchOp,
210
+ PatternRewriter &rewriter) const override ;
270
211
};
271
212
272
- LogicalResult IndexSwitchOpLowering::matchAndRewrite (
273
- IndexSwitchOp indexSwitchOp, OpAdaptor adaptor ,
274
- ConversionPatternRewriter &rewriter) const {
213
+ LogicalResult
214
+ IndexSwitchOpLowering::matchAndRewrite ( IndexSwitchOp indexSwitchOp,
215
+ PatternRewriter &rewriter) const {
275
216
Location loc = indexSwitchOp.getLoc ();
276
217
277
218
// Create an emitc::variable op for each result. These variables will be
278
219
// assigned to by emitc::assign ops within the case and default regions.
279
- SmallVector<Value> resultVariables;
280
- if (failed (createVariablesForResults (indexSwitchOp, getTypeConverter (),
281
- rewriter, resultVariables))) {
282
- return rewriter.notifyMatchFailure (indexSwitchOp,
283
- " create variables for results failed" );
284
- }
220
+ SmallVector<Value> resultVariables =
221
+ createVariablesForResults (indexSwitchOp, rewriter);
285
222
286
223
auto loweredSwitch = rewriter.create <emitc::SwitchOp>(
287
- loc, adaptor.getArg (), adaptor.getCases (), indexSwitchOp.getNumCases ());
224
+ loc, indexSwitchOp.getArg (), indexSwitchOp.getCases (),
225
+ indexSwitchOp.getNumCases ());
288
226
289
227
// Lowering all case regions.
290
- for (auto pair :
291
- llvm::zip (adaptor.getCaseRegions (), loweredSwitch.getCaseRegions ())) {
292
- if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
293
- *std::get<0 >(pair), std::get<1 >(pair)))) {
294
- return failure ();
295
- }
228
+ for (auto pair : llvm::zip (indexSwitchOp.getCaseRegions (),
229
+ loweredSwitch.getCaseRegions ())) {
230
+ lowerRegion (resultVariables, rewriter, std::get<0 >(pair),
231
+ std::get<1 >(pair));
296
232
}
297
233
298
234
// Lowering default region.
299
- if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
300
- adaptor.getDefaultRegion (),
301
- loweredSwitch.getDefaultRegion ()))) {
302
- return failure ();
303
- }
235
+ lowerRegion (resultVariables, rewriter, indexSwitchOp.getDefaultRegion (),
236
+ loweredSwitch.getDefaultRegion ());
304
237
305
238
rewriter.setInsertionPointAfter (indexSwitchOp);
306
239
SmallVector<Value> results = loadValues (resultVariables, rewriter, loc);
@@ -309,22 +242,15 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
309
242
return success ();
310
243
}
311
244
312
- void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
313
- TypeConverter &typeConverter) {
314
- patterns.add <ForLowering>(typeConverter, patterns.getContext ());
315
- patterns.add <IfLowering>(typeConverter, patterns.getContext ());
316
- patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
245
+ void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
246
+ patterns.add <ForLowering>(patterns.getContext ());
247
+ patterns.add <IfLowering>(patterns.getContext ());
248
+ patterns.add <IndexSwitchOpLowering>(patterns.getContext ());
317
249
}
318
250
319
251
void SCFToEmitCPass::runOnOperation () {
320
252
RewritePatternSet patterns (&getContext ());
321
- TypeConverter typeConverter;
322
- // Fallback converter
323
- // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324
- // Type converters are called most to least recently inserted
325
- typeConverter.addConversion ([](Type t) { return t; });
326
- populateEmitCSizeTTypeConversions (typeConverter);
327
- populateSCFToEmitCConversionPatterns (patterns, typeConverter);
253
+ populateSCFToEmitCConversionPatterns (patterns);
328
254
329
255
// Configure conversion to lower out SCF operations.
330
256
ConversionTarget target (getContext ());
0 commit comments