6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file implements a pass to convert memref ops into emitc ops.
9
+ // This file implements patterns to convert memref ops into emitc ops.
10
10
//
11
11
// ===----------------------------------------------------------------------===//
12
12
13
13
#include " mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
14
14
15
15
#include " mlir/Dialect/EmitC/IR/EmitC.h"
16
- #include " mlir/Dialect/Func/IR/FuncOps.h"
17
- #include " mlir/Dialect/Func/Transforms/FuncConversions.h"
18
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
19
17
#include " mlir/IR/Builders.h"
20
- #include " mlir/IR/BuiltinOps.h"
21
- #include " mlir/IR/IRMapping.h"
22
- #include " mlir/IR/MLIRContext.h"
23
18
#include " mlir/IR/PatternMatch.h"
24
- #include " mlir/Interfaces/FunctionInterfaces.h"
25
19
#include " mlir/Transforms/DialectConversion.h"
26
- #include " mlir/Transforms/Passes.h"
27
-
28
- namespace mlir {
29
- #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
30
- #include " mlir/Conversion/Passes.h.inc"
31
- } // namespace mlir
32
20
33
21
using namespace mlir ;
34
22
35
23
namespace {
24
+ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
25
+ using OpConversionPattern::OpConversionPattern;
36
26
37
- // / Disallow all memrefs even though we only have conversions
38
- // / for memrefs with static shape right now to have good diagnostics.
39
- bool isLegal (Type t) { return !isa<BaseMemRefType>(t); }
27
+ LogicalResult
28
+ matchAndRewrite (memref::AllocaOp op, OpAdaptor operands,
29
+ ConversionPatternRewriter &rewriter) const override {
40
30
41
- template <typename RangeT>
42
- std::enable_if_t <!std::is_convertible<RangeT, Type>::value &&
43
- !std::is_convertible<RangeT, Operation *>::value,
44
- bool >
45
- isLegal (RangeT &&range) {
46
- return llvm::all_of (range, [](Type type) { return isLegal (type); });
47
- }
31
+ if (!op.getType ().hasStaticShape ()) {
32
+ return rewriter.notifyMatchFailure (
33
+ op.getLoc (), " cannot transform alloca with dynamic shape" );
34
+ }
48
35
49
- bool isLegal (Operation *op) {
50
- return isLegal (op->getOperandTypes ()) && isLegal (op->getResultTypes ());
51
- }
36
+ if (op.getAlignment ().value_or (1 ) > 1 ) {
37
+ // TODO: Allow alignment if it is not more than the natural alignment
38
+ // of the C array.
39
+ return rewriter.notifyMatchFailure (
40
+ op.getLoc (), " cannot transform alloca with alignment requirement" );
41
+ }
52
42
53
- bool isSignatureLegal (FunctionType ty) {
54
- return isLegal (llvm::concat<const Type>(ty.getInputs (), ty.getResults ()));
55
- }
43
+ auto resultTy = getTypeConverter ()->convertType (op.getType ());
44
+ if (!resultTy) {
45
+ return rewriter.notifyMatchFailure (op.getLoc (), " cannot convert type" );
46
+ }
47
+ auto noInit = emitc::OpaqueAttr::get (getContext (), " " );
48
+ rewriter.replaceOpWithNewOp <emitc::VariableOp>(op, resultTy, noInit);
49
+ return success ();
50
+ }
51
+ };
56
52
57
53
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
58
54
using OpConversionPattern::OpConversionPattern;
@@ -61,8 +57,20 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
61
57
matchAndRewrite (memref::LoadOp op, OpAdaptor operands,
62
58
ConversionPatternRewriter &rewriter) const override {
63
59
64
- rewriter.replaceOpWithNewOp <emitc::SubscriptOp>(op, operands.getMemref (),
65
- operands.getIndices ());
60
+ auto resultTy = getTypeConverter ()->convertType (op.getType ());
61
+ if (!resultTy) {
62
+ return rewriter.notifyMatchFailure (op.getLoc (), " cannot convert type" );
63
+ }
64
+
65
+ auto subscript = rewriter.create <emitc::SubscriptOp>(
66
+ op.getLoc (), operands.getMemref (), operands.getIndices ());
67
+
68
+ auto noInit = emitc::OpaqueAttr::get (getContext (), " " );
69
+ auto var =
70
+ rewriter.create <emitc::VariableOp>(op.getLoc (), resultTy, noInit);
71
+
72
+ rewriter.create <emitc::AssignOp>(op.getLoc (), var, subscript);
73
+ rewriter.replaceOp (op, var);
66
74
return success ();
67
75
}
68
76
};
@@ -81,90 +89,26 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
81
89
return success ();
82
90
}
83
91
};
92
+ } // namespace
84
93
85
- struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
86
- using OpConversionPattern::OpConversionPattern;
87
-
88
- LogicalResult
89
- matchAndRewrite (memref::AllocaOp op, OpAdaptor operands,
90
- ConversionPatternRewriter &rewriter) const override {
91
-
92
- if (!op.getType ().hasStaticShape ()) {
93
- return rewriter.notifyMatchFailure (
94
- op.getLoc (), " cannot transform alloca with dynamic shape" );
95
- }
96
-
97
- if (op.getAlignment ().value_or (1 ) > 1 ) {
98
- // TODO: Allow alignment if it is not more than the natural alignment
99
- // of the C array.
100
- return rewriter.notifyMatchFailure (
101
- op.getLoc (), " cannot transform alloca with alignment requirement" );
102
- }
103
-
104
- auto resultTy = getTypeConverter ()->convertType (op.getType ());
105
- auto noInit = emitc::OpaqueAttr::get (getContext (), " " );
106
- rewriter.replaceOpWithNewOp <emitc::VariableOp>(op, resultTy, noInit);
107
- return success ();
108
- }
109
- };
110
-
111
- struct ConvertMemRefToEmitCPass
112
- : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
113
- void runOnOperation () override {
114
- TypeConverter converter;
115
- // Pass through for all other types.
116
- converter.addConversion ([](Type type) { return type; });
117
-
118
- converter.addConversion ([](MemRefType memRefType) -> std::optional<Type> {
119
- if (memRefType.hasStaticShape ()) {
94
+ void mlir::populateMemRefToEmitCTypeConversion (TypeConverter &typeConverter) {
95
+ typeConverter.addConversion (
96
+ [&](MemRefType memRefType) -> std::optional<Type> {
97
+ if (!memRefType.hasStaticShape () ||
98
+ !memRefType.getLayout ().isIdentity () || memRefType.getRank () == 0 ) {
99
+ return {};
100
+ }
101
+ Type convertedElementType =
102
+ typeConverter.convertType (memRefType.getElementType ());
103
+ if (!convertedElementType)
104
+ return {};
120
105
return emitc::ArrayType::get (memRefType.getShape (),
121
- memRefType.getElementType ());
122
- }
123
- return {};
124
- });
125
-
126
- converter.addConversion (
127
- [&converter](FunctionType ty) -> std::optional<Type> {
128
- SmallVector<Type> inputs;
129
- if (failed (converter.convertTypes (ty.getInputs (), inputs)))
130
- return std::nullopt;
131
-
132
- SmallVector<Type> results;
133
- if (failed (converter.convertTypes (ty.getResults (), results)))
134
- return std::nullopt;
135
-
136
- return FunctionType::get (ty.getContext (), inputs, results);
137
- });
138
-
139
- RewritePatternSet patterns (&getContext ());
140
- populateMemRefToEmitCConversionPatterns (patterns, converter);
141
-
142
- ConversionTarget target (getContext ());
143
- target.addDynamicallyLegalOp <func::FuncOp>(
144
- [](func::FuncOp op) { return isSignatureLegal (op.getFunctionType ()); });
145
- target.addDynamicallyLegalDialect <func::FuncDialect>(
146
- [](Operation *op) { return isLegal (op); });
147
- target.addIllegalDialect <memref::MemRefDialect>();
148
- target.addLegalDialect <emitc::EmitCDialect>();
149
-
150
- if (failed (applyPartialConversion (getOperation (), target,
151
- std::move (patterns))))
152
- return signalPassFailure ();
153
- }
154
- };
155
- } // namespace
106
+ convertedElementType);
107
+ });
108
+ }
156
109
157
110
void mlir::populateMemRefToEmitCConversionPatterns (RewritePatternSet &patterns,
158
111
TypeConverter &converter) {
159
-
160
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
161
- converter);
162
- populateCallOpTypeConversionPattern (patterns, converter);
163
- populateReturnOpTypeConversionPattern (patterns, converter);
164
- patterns.add <ConvertLoad, ConvertStore, ConvertAlloca>(converter,
112
+ patterns.add <ConvertAlloca, ConvertLoad, ConvertStore>(converter,
165
113
patterns.getContext ());
166
114
}
167
-
168
- std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass () {
169
- return std::make_unique<ConvertMemRefToEmitCPass>();
170
- }
0 commit comments