Skip to content

Commit c8275bc

Browse files
Simon CamphausenSimon Camphausen
authored andcommitted
Use TypeConverter in FuncToEmitC conversion
1 parent f6adbc2 commit c8275bc

File tree

7 files changed

+50
-14
lines changed

7 files changed

+50
-14
lines changed

mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
namespace mlir {
1313
class RewritePatternSet;
14+
class TypeConverter;
1415

15-
void populateFuncToEmitCPatterns(RewritePatternSet &patterns);
16+
void populateFuncToEmitCPatterns(const TypeConverter &typeConverter,
17+
RewritePatternSet &patterns);
1618
} // namespace mlir
1719

1820
#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H

mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitC.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ void mlir::populateConvertToEmitCTypeConverter(TypeConverter &typeConverter) {
3535
void mlir::populateConvertToEmitCPatterns(TypeConverter &typeConverter,
3636
RewritePatternSet &patterns) {
3737
populateArithToEmitCPatterns(typeConverter, patterns);
38-
populateFuncToEmitCPatterns(patterns);
38+
populateFuncToEmitCPatterns(typeConverter, patterns);
3939
populateMemRefToEmitCConversionPatterns(patterns, typeConverter);
4040
populateSCFToEmitCConversionPatterns(patterns);
41-
populateFunctionOpInterfaceTypeConversionPattern<emitc::FuncOp>(
42-
patterns, typeConverter);
4341
}

mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,6 @@ struct ConvertToEmitC final : impl::ConvertToEmitCBase<ConvertToEmitC> {
4545
target.addIllegalDialect<arith::ArithDialect, func::FuncDialect,
4646
memref::MemRefDialect, scf::SCFDialect>();
4747
target.addLegalDialect<emitc::EmitCDialect>();
48-
target.addDynamicallyLegalOp<emitc::FuncOp>(
49-
[&typeConverter](emitc::FuncOp op) {
50-
return typeConverter.isSignatureLegal(op.getFunctionType());
51-
});
5248

5349
populateConvertToEmitCTypeConverter(typeConverter);
5450
populateConvertToEmitCPatterns(typeConverter, patterns);

mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,34 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
5151
LogicalResult
5252
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
5353
ConversionPatternRewriter &rewriter) const override {
54+
FunctionType fnType = funcOp.getFunctionType();
5455

55-
if (funcOp.getFunctionType().getNumResults() > 1)
56+
if (fnType.getNumResults() > 1)
5657
return rewriter.notifyMatchFailure(
5758
funcOp, "only functions with zero or one result can be converted");
5859

60+
TypeConverter::SignatureConversion signatureConverter(
61+
fnType.getNumInputs());
62+
for (const auto &argType : enumerate(fnType.getInputs())) {
63+
auto convertedType = getTypeConverter()->convertType(argType.value());
64+
if (!convertedType)
65+
return failure();
66+
signatureConverter.addInputs(argType.index(), convertedType);
67+
}
68+
69+
Type resultType;
70+
if (fnType.getNumResults() == 1) {
71+
resultType = getTypeConverter()->convertType(fnType.getResult(0));
72+
if (!resultType)
73+
return failure();
74+
}
75+
5976
// Create the converted `emitc.func` op.
6077
emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
61-
funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
78+
funcOp.getLoc(), funcOp.getName(),
79+
FunctionType::get(rewriter.getContext(),
80+
signatureConverter.getConvertedTypes(),
81+
resultType ? TypeRange(resultType) : TypeRange()));
6282

6383
// Copy over all attributes other than the function name and type.
6484
for (const auto &namedAttr : funcOp->getAttrs()) {
@@ -80,9 +100,13 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
80100
newFuncOp.setSpecifiersAttr(specifiers);
81101
}
82102

83-
if (!funcOp.isDeclaration())
103+
if (!funcOp.isDeclaration()) {
84104
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
85105
newFuncOp.end());
106+
if (failed(rewriter.convertRegionTypes(
107+
&newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
108+
return failure();
109+
}
86110
rewriter.eraseOp(funcOp);
87111

88112
return success();
@@ -112,8 +136,10 @@ class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
112136
// Pattern population
113137
//===----------------------------------------------------------------------===//
114138

115-
void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) {
139+
void mlir::populateFuncToEmitCPatterns(const TypeConverter &typeConverter,
140+
RewritePatternSet &patterns) {
116141
MLIRContext *ctx = patterns.getContext();
117142

118-
patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(ctx);
143+
patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(
144+
typeConverter, ctx);
119145
}

mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ void ConvertFuncToEmitC::runOnOperation() {
3939
target.addIllegalOp<func::CallOp, func::FuncOp, func::ReturnOp>();
4040

4141
RewritePatternSet patterns(&getContext());
42-
populateFuncToEmitCPatterns(patterns);
42+
43+
TypeConverter typeConverter;
44+
typeConverter.addConversion([](Type type) { return type; });
45+
46+
populateFuncToEmitCPatterns(typeConverter, patterns);
4347

4448
if (failed(
4549
applyPartialConversion(getOperation(), target, std::move(patterns))))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt -convert-to-emitc %s -split-input-file -verify-diagnostics
2+
3+
func.func @block_args(%arg0: i1, %arg1: index, %arg2: index) -> index {
4+
// expected-error @+1 {{type mismatch for bb argument #0 of successor #0}}
5+
cf.cond_br %arg0, ^bb1(%arg1: index), ^bb2(%arg2: index)
6+
^bb1(%0: index):
7+
return %0 : index
8+
^bb2(%1: index):
9+
return %1 : index
10+
}

0 commit comments

Comments
 (0)