 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-#include <numeric>
-
 using namespace mlir;
 using namespace tosa;
 
-namespace {
-
-// Infer the type to which the input of a 'tosa.reshape' op must be cast when
-// lowered.
-TensorType inferReshapeInputType(TypedValue<TensorType> input,
-                                 ArrayRef<int64_t> newShape) {
-  // No need to cast input for non-empty target shape
-  if (!newShape.empty())
-    return input.getType();
-
-  // The input type must be cast into a tensor with the same rank and all static
-  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
-  // op that converts a dynamically shaped tensor into a 0D tensor. While such a
-  // construct is not incorrect on its own, bufferization cannot properly handle
-  // it at the moment, so we avoid it.
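-  // E.g., an input of type 'tensor<?x?xf32>' reshaped to rank 0 is first cast
-  // to 'tensor<1x1xf32>'.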
-  SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
-}
-
-// Infer the result type of 'tensor.expand_shape' in the collapse-expand
-// pair emitted for a 'tosa.reshape' op.
-TensorType inferReshapeExpandedType(TensorType inputType,
-                                    ArrayRef<int64_t> newShape) {
-  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
-  // with just '{}', as it will invoke the incorrect overload.
-  if (newShape.empty())
-    return inputType.clone(ArrayRef<int64_t>{});
-
-  // Check if the input is static, and if so, get its total size
-  bool inputIsStatic = inputType.hasStaticShape();
-  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
-
-  // Compute result shape
-  bool resultIsStatic = true;
-  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
-    // If this is not a placeholder, do not change it
-    if (size >= 0)
-      return size;
-
-    // If we do not know the total size of the tensor, keep this dimension
-    // dynamic in the result shape.
-    if (!inputIsStatic) {
-      resultIsStatic = false;
-      return ShapedType::kDynamic;
-    }
-
-    // Calculate the product of all elements in 'newShape' except for the -1
-    // placeholder, which we discard by negating the result.
-    int64_t totalSizeNoPlaceholder = -std::accumulate(
-        newShape.begin(), newShape.end(), 1, std::multiplies());
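-    // E.g., for newShape = {6, -1} the raw product is 6 * -1 = -6, so the
-    // negation yields totalSizeNoPlaceholder = 6.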
-
-    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
-    if (totalSizeNoPlaceholder == 0)
-      return 0;
-
-    // Resolve the placeholder as the quotient between the total tensor size and
-    // the product of all other sizes.
-    return totalSize / totalSizeNoPlaceholder;
-  });
-
-  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
-  // shaped input from being reshaped into a statically shaped result. We may
-  // simply turn the first result dimension dynamic to address this.
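-  // E.g., a 'tensor<?x6xf32>' input with newShape = {3, 2, 6} yields the
-  // expanded type 'tensor<?x2x6xf32>'.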
-  if (!inputIsStatic && resultIsStatic)
-    resultShape[0] = ShapedType::kDynamic;
-
-  // The 'tensor.expand_shape' op also forbids a statically shaped input from
-  // being reshaped into a dynamically shaped result, but the placeholder
-  // inference algorithm above guarantees that this will never be the case.
-  assert(!inputIsStatic || resultIsStatic);
-
-  // Create result type
-  return inputType.clone(resultShape);
-}
-
-// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
-// pair emitted for a 'tosa.reshape' op.
-TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
-  auto lhsShape = lhsType.getShape();
-  auto rhsShape = rhsType.getShape();
-
-  if (lhsShape.empty() || rhsShape.empty())
-    return lhsType.clone(ArrayRef<int64_t>{});
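+// Compute a shape that both 'lhsShape' and 'rhsShape' can be collapsed to,
+// e.g. lhsShape = {2, 3, 4} and rhsShape = {6, 4} yield the intermediate
+// shape {6, 4}. Returns false when no such shape exists.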
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+                                  ArrayRef<int64_t> rhsShape,
+                                  SmallVector<int64_t> &intermediateShape,
+                                  bool isDynamic) {
+  if (isDynamic) {
+    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+    intermediateShape = {ShapedType::kDynamic};
+    return true;
+  }
 
-  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
-    return lhsType.clone({ShapedType::kDynamic});
+  if (lhsShape.empty() || rhsShape.empty()) {
+    intermediateShape = {};
+    return true;
+  }
 
-  SmallVector<int64_t> intermediateShape;
   unsigned currLhsDim = 0, currRhsDim = 0;
   while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
     int64_t rhsSize = rhsShape[currRhsDim];
@@ -137,113 +62,174 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
     currLhsDim++;
   }
 
-  // Static shapes are guaranteed to be compatible by the op verifier, so all
-  // leftover dimensions should be 1.
-  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
-    assert(lhsShape[currLhsDim] == 1);
+  // If the iterators did not reach the end, any leftover dimensions must be
+  // equal to 1; otherwise no intermediate shape was found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1) {
+      return false;
+    }
   }
-  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
-    assert(rhsShape[currRhsDim] == 1);
+
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1) {
+      return false;
+    }
   }
-
-  return lhsType.clone(intermediateShape);
+
+  return true;
 }
 
-SmallVector<ReassociationExprs>
-createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
-  auto srcShape = cast<TensorType>(srcType).getShape();
-  auto dstShape = cast<TensorType>(dstType).getShape();
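+// Build the reassociation map that groups the dimensions of 'srcShape' into
+// those of 'dstShape', e.g. collapsing srcShape = {2, 3, 4} into
+// dstShape = {6, 4} produces the map [[d0, d1], [d2]].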
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
 
-  if (srcShape.empty() || dstShape.empty())
-    return {};
-
-  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
-    assert(dstShape.size() == 1);
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
     SmallVector<AffineExpr, 2> exprs;
-    for (auto i : llvm::seq<int64_t>(srcShape.size()))
-      exprs.push_back(builder.getAffineDimExpr(i));
-    return {exprs};
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
   }
 
-  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
+  reassociationMap.resize(dstShape.size());
   unsigned currSrcDim = 0, currDstDim = 0;
   while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
     int64_t dstSize = dstShape[currDstDim];
     int64_t srcSize = srcShape[currSrcDim];
     while (srcSize < dstSize && currSrcDim < srcShape.size()) {
       reassociationMap[currDstDim].push_back(
-          builder.getAffineDimExpr(currSrcDim++));
+          rewriter.getAffineDimExpr(currSrcDim++));
       srcSize *= srcShape[currSrcDim];
     }
     if (srcSize == dstSize) {
       reassociationMap[currDstDim].push_back(
-          builder.getAffineDimExpr(currSrcDim++));
+          rewriter.getAffineDimExpr(currSrcDim++));
       // If the next dim in collapsedShape is not 1, treat subsequent dims in
       // expandedShape which are 1 to be collapsed.
       if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
         while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
           reassociationMap[currDstDim].push_back(
-              builder.getAffineDimExpr(currSrcDim++));
+              rewriter.getAffineDimExpr(currSrcDim++));
         }
       }
     }
     currDstDim++;
   }
 
-  // If the source and target shapes are compatible, both iterators must have
-  // reached the end. This condition is guaranteed by the op verifier for
-  // static shapes.
-  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
-  return reassociationMap;
+  // If either iterator did not reach the end, there are leftover dimensions,
+  // which implies a mismatch in shape.
+  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
 }
 
-// Create a tensor.collapse_shape op that reshapes the input into the given
-// result type.
-Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
-                     Value input) {
-  auto reassociationMap =
-      createReassociationMapForCollapse(builder, input.getType(), resultType);
-  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
-                                                       reassociationMap);
+namespace {
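+// Emit a tensor.collapse_shape op reshaping 'operand' into 'resultTy'.
+// Returns a null Value (after recording a match-failure remark) when no valid
+// collapse exists.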
+Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
+                     ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
+
+  bool isDynamic = !operandTy.hasStaticShape();
+
+  if (isDynamic && resultTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot collapse dynamic dims to more than one dimension");
+    return {};
+  }
+
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                          resultTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
+    return {};
+  }
+
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot collapse into given shape");
+    return {};
+  }
+  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
+                                                  reassociationMap);
 }
 
-// Create a tensor.expand_shape op that reshapes the input into the given result
-// type.
-Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
-                   Value input) {
-  auto reassociationMap =
-      createReassociationMapForCollapse(builder, resultType, input.getType());
-  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
-                                                     reassociationMap);
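+// Emit a tensor.expand_shape op reshaping 'operand' into 'resultTy'. Returns
+// a null Value when the expansion is not expressible.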
+Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
+                   ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
+
+  bool isDynamic = !operandTy.hasStaticShape();
+
+  if (isDynamic && operandTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot expand dynamic dims from more than one dimension");
+    return {};
+  }
+
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                          operandTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to expand into an incompatible shape");
+    return {};
+  }
+
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic) ||
+      intermediateShape != operandTy.getShape()) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot expand into given shape");
+    return {};
+  }
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+                                                reassociationMap);
 }
 
-class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
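+// Lower tosa.reshape to a tensor.collapse_shape followed by a
+// tensor.expand_shape, routed through the intermediate shape computed by
+// findIntermediateShape.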
+class ReshapeConverterCollapseExpand
+    : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    auto loc = reshape.getLoc();
-    auto resultType = reshape.getResult().getType();
-    auto input = reshape.getInput1();
-    auto newShape = reshape.getNewShape();
-
-    // Infer all intermediate types
-    auto inputType = inferReshapeInputType(input, newShape);
-    auto expandedType = inferReshapeExpandedType(inputType, newShape);
-    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
-
-    // Cast input if needed
-    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
-
-    // Emit collapse-expand pair
-    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
-    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
-
-    // Cast to final result type if needed
-    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
-    rewriter.replaceOp(reshape, result);
+    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+    ShapedType resultTy = cast<ShapedType>(reshape.getType());
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot identify an intermediate shape between "
+                   "the given two shapes");
+    }
+    auto intermediateTy = RankedTensorType::get(
+        intermediateShape, reshape.getType().getElementType());
+
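+    // E.g., reshaping tensor<2x3xf32> into tensor<3x2xf32> goes through the
+    // intermediate tensor<6xf32>: collapse {2, 3} -> {6}, then expand
+    // {6} -> {3, 2}.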
+    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
+                                    adaptor.getInput1());
+    if (!collapse)
+      return failure();
+
+    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
+    if (!expand)
+      return failure();
+
+    rewriter.replaceOp(reshape, expand);
     return success();
   }
 };
@@ -430,10 +416,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<
-      ConcatConverter,
-      PadConverter,
-      ReshapeConverter,
-      SliceConverter
-  >(patterns->getContext());
+  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
+      patterns->getContext());
+
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
 }