Skip to content

Commit df728cf

Browse files
committed
Revert "[MLIR][SCFToEmitC] Convert types while converting from SCF to EmitC (llvm#118940)"
This reverts commit 450c6b0.
1 parent 450c6b0 commit df728cf

File tree

4 files changed

+80
-228
lines changed

4 files changed

+80
-228
lines changed

mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1010
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1111

12-
#include "mlir/Transforms/DialectConversion.h"
1312
#include <memory>
1413

1514
namespace mlir {
@@ -20,8 +19,7 @@ class RewritePatternSet;
2019
#include "mlir/Conversion/Passes.h.inc"
2120

2221
/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
23-
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
24-
TypeConverter &typeConverter);
22+
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
2523
} // namespace mlir
2624

2725
#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 66 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17-
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1817
#include "mlir/Dialect/SCF/IR/SCF.h"
1918
#include "mlir/IR/Builders.h"
2019
#include "mlir/IR/BuiltinOps.h"
@@ -40,22 +39,21 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
4039

4140
// Lower scf::for to emitc::for, implementing result values using
4241
// emitc::variable's updated within the loop body.
43-
struct ForLowering : public OpConversionPattern<ForOp> {
44-
using OpConversionPattern<ForOp>::OpConversionPattern;
42+
struct ForLowering : public OpRewritePattern<ForOp> {
43+
using OpRewritePattern<ForOp>::OpRewritePattern;
4544

46-
LogicalResult
47-
matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
48-
ConversionPatternRewriter &rewriter) const override;
45+
LogicalResult matchAndRewrite(ForOp forOp,
46+
PatternRewriter &rewriter) const override;
4947
};
5048

5149
// Create an uninitialized emitc::variable op for each result of the given op.
5250
template <typename T>
53-
static LogicalResult
54-
createVariablesForResults(T op, const TypeConverter *typeConverter,
55-
ConversionPatternRewriter &rewriter,
56-
SmallVector<Value> &resultVariables) {
51+
static SmallVector<Value> createVariablesForResults(T op,
52+
PatternRewriter &rewriter) {
53+
SmallVector<Value> resultVariables;
54+
5755
if (!op.getNumResults())
58-
return success();
56+
return resultVariables;
5957

6058
Location loc = op->getLoc();
6159
MLIRContext *context = op.getContext();
@@ -64,23 +62,21 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
6462
rewriter.setInsertionPoint(op);
6563

6664
for (OpResult result : op.getResults()) {
67-
Type resultType = typeConverter->convertType(result.getType());
68-
if (!resultType)
69-
return rewriter.notifyMatchFailure(op, "result type conversion failed");
65+
Type resultType = result.getType();
7066
Type varType = emitc::LValueType::get(resultType);
7167
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
7268
emitc::VariableOp var =
7369
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
7470
resultVariables.push_back(var);
7571
}
7672

77-
return success();
73+
return resultVariables;
7874
}
7975

8076
// Create a series of assign ops assigning given values to given variables at
8177
// the current insertion point of given rewriter.
82-
static void assignValues(ValueRange values, ValueRange variables,
83-
ConversionPatternRewriter &rewriter, Location loc) {
78+
static void assignValues(ValueRange values, SmallVector<Value> &variables,
79+
PatternRewriter &rewriter, Location loc) {
8480
for (auto [value, var] : llvm::zip(values, variables))
8581
rewriter.create<emitc::AssignOp>(loc, var, value);
8682
}
@@ -93,58 +89,46 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
9389
});
9490
}
9591

96-
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
97-
ConversionPatternRewriter &rewriter,
98-
scf::YieldOp yield) {
92+
static void lowerYield(SmallVector<Value> &resultVariables,
93+
PatternRewriter &rewriter, scf::YieldOp yield) {
9994
Location loc = yield.getLoc();
95+
ValueRange operands = yield.getOperands();
10096

10197
OpBuilder::InsertionGuard guard(rewriter);
10298
rewriter.setInsertionPoint(yield);
10399

104-
SmallVector<Value> yieldOperands;
105-
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
106-
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
107-
}
108-
109-
assignValues(yieldOperands, resultVariables, rewriter, loc);
100+
assignValues(operands, resultVariables, rewriter, loc);
110101

111102
rewriter.create<emitc::YieldOp>(loc);
112103
rewriter.eraseOp(yield);
113-
114-
return success();
115104
}
116105

117106
// Lower the contents of an scf::if/scf::index_switch regions to an
118107
// emitc::if/emitc::switch region. The contents of the lowering region is
119108
// moved into the respective lowered region, but the scf::yield is replaced not
120109
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
121110
// set the yielded values into the result variables.
122-
static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
123-
ConversionPatternRewriter &rewriter,
124-
Region &region, Region &loweredRegion) {
111+
static void lowerRegion(SmallVector<Value> &resultVariables,
112+
PatternRewriter &rewriter, Region &region,
113+
Region &loweredRegion) {
125114
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
126115
Operation *terminator = loweredRegion.back().getTerminator();
127-
return lowerYield(op, resultVariables, rewriter,
128-
cast<scf::YieldOp>(terminator));
116+
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
129117
}
130118

131-
LogicalResult
132-
ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
133-
ConversionPatternRewriter &rewriter) const {
119+
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
120+
PatternRewriter &rewriter) const {
134121
Location loc = forOp.getLoc();
135122

136123
// Create an emitc::variable op for each result. These variables will be
137124
// assigned to by emitc::assign ops within the loop body.
138-
SmallVector<Value> resultVariables;
139-
if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
140-
resultVariables)))
141-
return rewriter.notifyMatchFailure(forOp,
142-
"create variables for results failed");
125+
SmallVector<Value> resultVariables =
126+
createVariablesForResults(forOp, rewriter);
143127

144-
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
128+
assignValues(forOp.getInits(), resultVariables, rewriter, loc);
145129

146130
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
147-
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
131+
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
148132

149133
Block *loweredBody = loweredFor.getBody();
150134

@@ -159,27 +143,13 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
159143

160144
rewriter.restoreInsertionPoint(ip);
161145

162-
// Convert the original region types into the new types by adding unrealized
163-
// casts in the beginning of the loop. This performs the conversion in place.
164-
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
165-
*getTypeConverter(), nullptr))) {
166-
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
167-
}
168-
169-
// Register the replacements for the block arguments and inline the body of
170-
// the scf.for loop into the body of the emitc::for loop.
171-
Block *scfBody = &(forOp.getRegion().front());
172146
SmallVector<Value> replacingValues;
173147
replacingValues.push_back(loweredFor.getInductionVar());
174148
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
175-
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
176149

177-
auto result = lowerYield(forOp, resultVariables, rewriter,
178-
cast<scf::YieldOp>(loweredBody->getTerminator()));
179-
180-
if (failed(result)) {
181-
return result;
182-
}
150+
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
151+
lowerYield(resultVariables, rewriter,
152+
cast<scf::YieldOp>(loweredBody->getTerminator()));
183153

184154
// Load variables into SSA values after the for loop.
185155
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
@@ -190,66 +160,38 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
190160

191161
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
192162
// updated within the then and else regions.
193-
struct IfLowering : public OpConversionPattern<IfOp> {
194-
using OpConversionPattern<IfOp>::OpConversionPattern;
163+
struct IfLowering : public OpRewritePattern<IfOp> {
164+
using OpRewritePattern<IfOp>::OpRewritePattern;
195165

196-
LogicalResult
197-
matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
198-
ConversionPatternRewriter &rewriter) const override;
166+
LogicalResult matchAndRewrite(IfOp ifOp,
167+
PatternRewriter &rewriter) const override;
199168
};
200169

201170
} // namespace
202171

203-
LogicalResult
204-
IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
205-
ConversionPatternRewriter &rewriter) const {
172+
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
173+
PatternRewriter &rewriter) const {
206174
Location loc = ifOp.getLoc();
207175

208176
// Create an emitc::variable op for each result. These variables will be
209177
// assigned to by emitc::assign ops within the then & else regions.
210-
SmallVector<Value> resultVariables;
211-
if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
212-
resultVariables)))
213-
return rewriter.notifyMatchFailure(ifOp,
214-
"create variables for results failed");
215-
216-
// Utility function to lower the contents of an scf::if region to an emitc::if
217-
// region. The contents of the scf::if regions is moved into the respective
218-
// emitc::if regions, but the scf::yield is replaced not only with an
219-
// emitc::yield, but also with a sequence of emitc::assign ops that set the
220-
// yielded values into the result variables.
221-
auto lowerRegion = [&resultVariables, &rewriter,
222-
&ifOp](Region &region, Region &loweredRegion) {
223-
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
224-
Operation *terminator = loweredRegion.back().getTerminator();
225-
auto result = lowerYield(ifOp, resultVariables, rewriter,
226-
cast<scf::YieldOp>(terminator));
227-
if (failed(result)) {
228-
return result;
229-
}
230-
return success();
231-
};
232-
233-
Region &thenRegion = adaptor.getThenRegion();
234-
Region &elseRegion = adaptor.getElseRegion();
178+
SmallVector<Value> resultVariables =
179+
createVariablesForResults(ifOp, rewriter);
180+
181+
Region &thenRegion = ifOp.getThenRegion();
182+
Region &elseRegion = ifOp.getElseRegion();
235183

236184
bool hasElseBlock = !elseRegion.empty();
237185

238186
auto loweredIf =
239-
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
187+
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
240188

241189
Region &loweredThenRegion = loweredIf.getThenRegion();
242-
auto result = lowerRegion(thenRegion, loweredThenRegion);
243-
if (failed(result)) {
244-
return result;
245-
}
190+
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
246191

247192
if (hasElseBlock) {
248193
Region &loweredElseRegion = loweredIf.getElseRegion();
249-
auto result = lowerRegion(elseRegion, loweredElseRegion);
250-
if (failed(result)) {
251-
return result;
252-
}
194+
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
253195
}
254196

255197
rewriter.setInsertionPointAfter(ifOp);
@@ -261,46 +203,37 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
261203

262204
// Lower scf::index_switch to emitc::switch, implementing result values as
263205
// emitc::variable's updated within the case and default regions.
264-
struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
265-
using OpConversionPattern::OpConversionPattern;
206+
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
207+
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
266208

267-
LogicalResult
268-
matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269-
ConversionPatternRewriter &rewriter) const override;
209+
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
210+
PatternRewriter &rewriter) const override;
270211
};
271212

272-
LogicalResult IndexSwitchOpLowering::matchAndRewrite(
273-
IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
274-
ConversionPatternRewriter &rewriter) const {
213+
LogicalResult
214+
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
215+
PatternRewriter &rewriter) const {
275216
Location loc = indexSwitchOp.getLoc();
276217

277218
// Create an emitc::variable op for each result. These variables will be
278219
// assigned to by emitc::assign ops within the case and default regions.
279-
SmallVector<Value> resultVariables;
280-
if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
281-
rewriter, resultVariables))) {
282-
return rewriter.notifyMatchFailure(indexSwitchOp,
283-
"create variables for results failed");
284-
}
220+
SmallVector<Value> resultVariables =
221+
createVariablesForResults(indexSwitchOp, rewriter);
285222

286223
auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
287-
loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
224+
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
225+
indexSwitchOp.getNumCases());
288226

289227
// Lowering all case regions.
290-
for (auto pair :
291-
llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
292-
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
293-
*std::get<0>(pair), std::get<1>(pair)))) {
294-
return failure();
295-
}
228+
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
229+
loweredSwitch.getCaseRegions())) {
230+
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
231+
std::get<1>(pair));
296232
}
297233

298234
// Lowering default region.
299-
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
300-
adaptor.getDefaultRegion(),
301-
loweredSwitch.getDefaultRegion()))) {
302-
return failure();
303-
}
235+
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
236+
loweredSwitch.getDefaultRegion());
304237

305238
rewriter.setInsertionPointAfter(indexSwitchOp);
306239
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
@@ -309,22 +242,15 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
309242
return success();
310243
}
311244

312-
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
313-
TypeConverter &typeConverter) {
314-
patterns.add<ForLowering>(typeConverter, patterns.getContext());
315-
patterns.add<IfLowering>(typeConverter, patterns.getContext());
316-
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
245+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
246+
patterns.add<ForLowering>(patterns.getContext());
247+
patterns.add<IfLowering>(patterns.getContext());
248+
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
317249
}
318250

319251
void SCFToEmitCPass::runOnOperation() {
320252
RewritePatternSet patterns(&getContext());
321-
TypeConverter typeConverter;
322-
// Fallback converter
323-
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324-
// Type converters are called most to least recently inserted
325-
typeConverter.addConversion([](Type t) { return t; });
326-
populateEmitCSizeTTypeConversions(typeConverter);
327-
populateSCFToEmitCConversionPatterns(patterns, typeConverter);
253+
populateSCFToEmitCConversionPatterns(patterns);
328254

329255
// Configure conversion to lower out SCF operations.
330256
ConversionTarget target(getContext());

0 commit comments

Comments
 (0)