Skip to content

Commit 69cc7e1

Browse files
authored
[FXML-4614] Use EmitC index types in all passes creating EmitC (#186)
1 parent 23bebca commit 69cc7e1

File tree

17 files changed

+272
-69
lines changed

17 files changed

+272
-69
lines changed

mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace mlir {
1313
class RewritePatternSet;
1414
class TypeConverter;
1515

16-
void populateArithToEmitCPatterns(TypeConverter &typeConverter,
17-
RewritePatternSet &patterns);
16+
void populateArithToEmitCPatterns(RewritePatternSet &patterns,
17+
TypeConverter &typeConverter);
1818
} // namespace mlir
1919

2020
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H

mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
1010
#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
1111

12+
#include "mlir/Transforms/DialectConversion.h"
1213
namespace mlir {
1314
class RewritePatternSet;
1415

15-
void populateFuncToEmitCPatterns(RewritePatternSet &patterns);
16+
void populateFuncToEmitCPatterns(RewritePatternSet &patterns,
17+
TypeConverter &typeConverter);
1618
} // namespace mlir
1719

1820
#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H

mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1010
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1111

12+
#include "mlir/Transforms/DialectConversion.h"
1213
#include <memory>
1314

1415
namespace mlir {
@@ -19,7 +20,8 @@ class RewritePatternSet;
1920
#include "mlir/Conversion/Passes.h.inc"
2021

2122
/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
22-
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
23+
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
24+
TypeConverter &typeConverter);
2325
} // namespace mlir
2426

2527
#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
7676

7777
static bool isValidElementType(Type type) {
7878
return type.isIntOrIndexOrFloat() ||
79+
emitc::isAnySizeTType(type) ||
7980
llvm::isa<PointerType, OpaqueType>(type);
8081
}
8182
}];

mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Transforms/DialectConversion.h"
9+
#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H
10+
#define MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H
1011

1112
namespace mlir {
12-
void populateEmitCSizeTypeConversionPatterns(mlir::TypeConverter &converter);
13+
class TypeConverter;
14+
void populateEmitCSizeTypeConversions(TypeConverter &converter);
1315
} // namespace mlir
16+
17+
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,11 +563,11 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
563563
// Pattern population
564564
//===----------------------------------------------------------------------===//
565565

566-
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
567-
RewritePatternSet &patterns) {
566+
void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns,
567+
TypeConverter &typeConverter) {
568568
MLIRContext *ctx = patterns.getContext();
569569

570-
mlir::populateEmitCSizeTypeConversionPatterns(typeConverter);
570+
mlir::populateEmitCSizeTypeConversions(typeConverter);
571571

572572
// clang-format off
573573
patterns.add<

mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
1717
#include "mlir/Dialect/Arith/IR/Arith.h"
1818
#include "mlir/Dialect/EmitC/IR/EmitC.h"
19+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1920
#include "mlir/Pass/Pass.h"
2021
#include "mlir/Transforms/DialectConversion.h"
2122

@@ -43,9 +44,11 @@ void ConvertArithToEmitC::runOnOperation() {
4344
RewritePatternSet patterns(&getContext());
4445

4546
TypeConverter typeConverter;
46-
typeConverter.addConversion([](Type type) { return type; });
47-
48-
populateArithToEmitCPatterns(typeConverter, patterns);
47+
// Fallback converter
48+
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
49+
// Type converters are called most to least recently inserted
50+
typeConverter.addConversion([](Type t) { return t; });
51+
populateArithToEmitCPatterns(patterns, typeConverter);
4952

5053
if (failed(
5154
applyPartialConversion(getOperation(), target, std::move(patterns))))

mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,17 @@ class CallOpConversion final : public OpConversionPattern<func::CallOp> {
3636
return rewriter.notifyMatchFailure(
3737
callOp, "only functions with zero or one result can be converted");
3838

39+
// Convert the original function results.
40+
Type resultTy = nullptr;
41+
if (callOp.getNumResults()) {
42+
resultTy = typeConverter->convertType(callOp.getResult(0).getType());
43+
if (!resultTy)
44+
return rewriter.notifyMatchFailure(
45+
callOp, "function return type conversion failed");
46+
}
47+
3948
rewriter.replaceOpWithNewOp<emitc::CallOp>(
40-
callOp,
41-
callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr,
42-
adaptor.getOperands(), callOp->getAttrs());
49+
callOp, resultTy, adaptor.getOperands(), callOp->getAttrs());
4350

4451
return success();
4552
}
@@ -53,13 +60,34 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
5360
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
5461
ConversionPatternRewriter &rewriter) const override {
5562

56-
if (funcOp.getFunctionType().getNumResults() > 1)
63+
FunctionType type = funcOp.getFunctionType();
64+
if (!type)
65+
return failure();
66+
67+
if (type.getNumResults() > 1)
5768
return rewriter.notifyMatchFailure(
5869
funcOp, "only functions with zero or one result can be converted");
5970

71+
const TypeConverter *converter = getTypeConverter();
72+
73+
// Convert function signature
74+
TypeConverter::SignatureConversion signatureConversion(type.getNumInputs());
75+
SmallVector<Type, 1> convertedResults;
76+
if (failed(converter->convertSignatureArgs(type.getInputs(),
77+
signatureConversion)) ||
78+
failed(converter->convertTypes(type.getResults(), convertedResults)) ||
79+
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
80+
*converter, &signatureConversion)))
81+
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
82+
83+
// Convert the function type
84+
auto convertedFunctionType = FunctionType::get(
85+
rewriter.getContext(), signatureConversion.getConvertedTypes(),
86+
convertedResults);
87+
6088
// Create the converted `emitc.func` op.
6189
emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
62-
funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
90+
funcOp.getLoc(), funcOp.getName(), convertedFunctionType);
6391

6492
// Copy over all attributes other than the function name and type.
6593
for (const auto &namedAttr : funcOp->getAttrs()) {
@@ -113,8 +141,10 @@ class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
113141
// Pattern population
114142
//===----------------------------------------------------------------------===//
115143

116-
void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) {
144+
void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns,
145+
TypeConverter &typeConverter) {
117146
MLIRContext *ctx = patterns.getContext();
118147

119-
patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(ctx);
148+
patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(
149+
typeConverter, ctx);
120150
}

mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718
#include "mlir/Dialect/Func/IR/FuncOps.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
@@ -33,13 +34,20 @@ struct ConvertFuncToEmitC
3334
} // namespace
3435

3536
void ConvertFuncToEmitC::runOnOperation() {
37+
TypeConverter typeConverter;
38+
// Fallback converter
39+
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
40+
// Type converters are called most to least recently inserted
41+
typeConverter.addConversion([](Type t) { return t; });
42+
populateEmitCSizeTypeConversions(typeConverter);
43+
3644
ConversionTarget target(getContext());
3745

3846
target.addLegalDialect<emitc::EmitCDialect>();
3947
target.addIllegalOp<func::CallOp, func::FuncOp, func::ReturnOp>();
4048

4149
RewritePatternSet patterns(&getContext());
42-
populateFuncToEmitCPatterns(patterns);
50+
populateFuncToEmitCPatterns(patterns, typeConverter);
4351

4452
if (failed(
4553
applyPartialConversion(getOperation(), target, std::move(patterns))))

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
@@ -33,12 +34,13 @@ struct ConvertMemRefToEmitCPass
3334

3435
// Fallback for other types.
3536
converter.addConversion([](Type type) -> std::optional<Type> {
36-
if (isa<MemRefType>(type))
37-
return {};
38-
return type;
37+
if (emitc::isSupportedEmitCType(type))
38+
return type;
39+
return {};
3940
});
4041

4142
populateMemRefToEmitCTypeConversion(converter);
43+
populateEmitCSizeTypeConversions(converter);
4244

4345
RewritePatternSet patterns(&getContext());
4446
populateMemRefToEmitCConversionPatterns(patterns, converter);

0 commit comments

Comments
 (0)