@@ -14,52 +14,46 @@ using namespace mlir;
14
14
using namespace mlir ::func;
15
15
16
16
// ===----------------------------------------------------------------------===//
17
- // ValueDecomposer
17
+ // Helper functions
18
18
// ===----------------------------------------------------------------------===//
19
19
20
- void ValueDecomposer::decomposeValue (OpBuilder &builder, Location loc,
21
- Type type, Value value,
22
- SmallVectorImpl<Value> &results) {
23
- for (auto &conversion : decomposeValueConversions)
24
- if (conversion (builder, loc, type, value, results))
25
- return ;
26
- results.push_back (value);
20
+ // / If the given value can be decomposed with the type converter, decompose it.
21
+ // / Otherwise, return the given value.
22
+ static SmallVector<Value> decomposeValue (OpBuilder &builder, Location loc,
23
+ Value value,
24
+ const TypeConverter *converter) {
25
+ // Try to convert the given value's type. If that fails, just return the
26
+ // given value.
27
+ SmallVector<Type> convertedTypes;
28
+ if (failed (converter->convertType (value.getType (), convertedTypes)))
29
+ return {value};
30
+ if (convertedTypes.empty ())
31
+ return {};
32
+
33
+ // If the given value's type is already legal, just return the given value.
34
+ TypeRange convertedTypeRange (convertedTypes);
35
+ if (convertedTypeRange == TypeRange (value.getType ()))
36
+ return {value};
37
+
38
+ // Try to materialize a target conversion. If the materialization did not
39
+ // produce values of the requested type, the materialization failed. Just
40
+ // return the given value in that case.
41
+ SmallVector<Value> result = converter->materializeTargetConversion (
42
+ builder, loc, convertedTypeRange, value);
43
+ if (result.empty ())
44
+ return {value};
45
+ return result;
27
46
}
28
47
29
- // ===----------------------------------------------------------------------===//
30
- // DecomposeCallGraphTypesOpConversionPattern
31
- // ===----------------------------------------------------------------------===//
32
-
33
- namespace {
34
- // / Base OpConversionPattern class to make a ValueDecomposer available to
35
- // / inherited patterns.
36
- template <typename SourceOp>
37
- class DecomposeCallGraphTypesOpConversionPattern
38
- : public OpConversionPattern<SourceOp> {
39
- public:
40
- DecomposeCallGraphTypesOpConversionPattern (const TypeConverter &typeConverter,
41
- MLIRContext *context,
42
- ValueDecomposer &decomposer,
43
- PatternBenefit benefit = 1 )
44
- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45
- decomposer (decomposer) {}
46
-
47
- protected:
48
- ValueDecomposer &decomposer;
49
- };
50
- } // namespace
51
-
52
48
// ===----------------------------------------------------------------------===//
53
49
// DecomposeCallGraphTypesForFuncArgs
54
50
// ===----------------------------------------------------------------------===//
55
51
56
52
namespace {
57
- // / Expand function arguments according to the provided TypeConverter and
58
- // / ValueDecomposer.
53
+ // / Expand function arguments according to the provided TypeConverter.
59
54
struct DecomposeCallGraphTypesForFuncArgs
60
- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61
- using DecomposeCallGraphTypesOpConversionPattern::
62
- DecomposeCallGraphTypesOpConversionPattern;
55
+ : public OpConversionPattern<func::FuncOp> {
56
+ using OpConversionPattern::OpConversionPattern;
63
57
64
58
LogicalResult
65
59
matchAndRewrite (func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +94,22 @@ struct DecomposeCallGraphTypesForFuncArgs
100
94
// ===----------------------------------------------------------------------===//
101
95
102
96
namespace {
103
- // / Expand return operands according to the provided TypeConverter and
104
- // / ValueDecomposer.
97
+ // / Expand return operands according to the provided TypeConverter.
105
98
struct DecomposeCallGraphTypesForReturnOp
106
- : public DecomposeCallGraphTypesOpConversionPattern <ReturnOp> {
107
- using DecomposeCallGraphTypesOpConversionPattern::
108
- DecomposeCallGraphTypesOpConversionPattern;
99
+ : public OpConversionPattern <ReturnOp> {
100
+ using OpConversionPattern::OpConversionPattern;
101
+
109
102
LogicalResult
110
103
matchAndRewrite (ReturnOp op, OpAdaptor adaptor,
111
104
ConversionPatternRewriter &rewriter) const final {
112
105
SmallVector<Value, 2 > newOperands;
113
- for (Value operand : adaptor.getOperands ())
114
- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
115
- operand, newOperands);
106
+ for (Value operand : adaptor.getOperands ()) {
107
+ // TODO: We can directly take the values from the adaptor once this is a
108
+ // 1:N conversion pattern.
109
+ llvm::append_range (newOperands,
110
+ decomposeValue (rewriter, operand.getLoc (), operand,
111
+ getTypeConverter ()));
112
+ }
116
113
rewriter.replaceOpWithNewOp <ReturnOp>(op, newOperands);
117
114
return success ();
118
115
}
@@ -124,22 +121,23 @@ struct DecomposeCallGraphTypesForReturnOp
124
121
// ===----------------------------------------------------------------------===//
125
122
126
123
namespace {
127
- // / Expand call op operands and results according to the provided TypeConverter
128
- // / and ValueDecomposer.
129
- struct DecomposeCallGraphTypesForCallOp
130
- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131
- using DecomposeCallGraphTypesOpConversionPattern::
132
- DecomposeCallGraphTypesOpConversionPattern;
124
+ // / Expand call op operands and results according to the provided TypeConverter.
125
+ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern <CallOp> {
126
+ using OpConversionPattern::OpConversionPattern;
133
127
134
128
LogicalResult
135
129
matchAndRewrite (CallOp op, OpAdaptor adaptor,
136
130
ConversionPatternRewriter &rewriter) const final {
137
131
138
132
// Create the operands list of the new `CallOp`.
139
133
SmallVector<Value, 2 > newOperands;
140
- for (Value operand : adaptor.getOperands ())
141
- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
142
- operand, newOperands);
134
+ for (Value operand : adaptor.getOperands ()) {
135
+ // TODO: We can directly take the values from the adaptor once this is a
136
+ // 1:N conversion pattern.
137
+ llvm::append_range (newOperands,
138
+ decomposeValue (rewriter, operand.getLoc (), operand,
139
+ getTypeConverter ()));
140
+ }
143
141
144
142
// Create the new result types for the new `CallOp` and track the indices in
145
143
// the new call op's results that correspond to the old call op's results.
@@ -189,9 +187,8 @@ struct DecomposeCallGraphTypesForCallOp
189
187
190
188
void mlir::populateDecomposeCallGraphTypesPatterns (
191
189
MLIRContext *context, const TypeConverter &typeConverter,
192
- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
190
+ RewritePatternSet &patterns) {
193
191
patterns
194
192
.add <DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196
- decomposer);
193
+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197
194
}
0 commit comments