 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace tosa;
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
+namespace {
 
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
+// Infer the type to which the input of a 'tosa.reshape' op must be cast when
+// lowered.
+TensorType inferReshapeInputType(TypedValue<TensorType> input,
+                                 ArrayRef<int64_t> newShape) {
+  // No need to cast input for non-empty target shape
+  if (!newShape.empty())
+    return input.getType();
+
+  // The input type must be cast into a tensor with the same rank and all static
+  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
+  // op that converts a dynamically shaped tensor into a 0D tensor. While such a
+  // construct is not incorrect on its own, bufferization cannot properly handle
+  // it at the moment, so we avoid it.
+  SmallVector<int64_t> shape(input.getType().getRank(), 1);
+  return input.getType().clone(shape);
+}
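
To illustrate what this cast buys, here is a rough sketch (hypothetical IR, not taken from this patch's tests) of a dynamic rank-2 input reshaped to a 0D result: the all-ones cast lets tensor.collapse_shape start from a fully static type instead of collapsing tensor<?x?xf32> directly into a 0D tensor.

    func.func @reshape_to_0d(%arg0: tensor<?x?xf32>) -> tensor<f32> {
      %cast = tensor.cast %arg0 : tensor<?x?xf32> to tensor<1x1xf32>
      %collapsed = tensor.collapse_shape %cast []
          : tensor<1x1xf32> into tensor<f32>
      return %collapsed : tensor<f32>
    }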
+
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapeExpandedType(TensorType inputType,
+                                    ArrayRef<int64_t> newShape) {
+  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
+  // with just '{}', as it will invoke the incorrect overload.
+  if (newShape.empty())
+    return inputType.clone(ArrayRef<int64_t>{});
+
+  // Check if the input is static, and if so, get its total size
+  bool inputIsStatic = inputType.hasStaticShape();
+  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
+
+  // Compute result shape
+  bool resultIsStatic = true;
+  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+    // If this is not a placeholder, do not change it
+    if (size >= 0)
+      return size;
+
+    // If we do not know the total size of the tensor, keep this dimension
+    // dynamic in the result shape.
+    if (!inputIsStatic) {
+      resultIsStatic = false;
+      return ShapedType::kDynamic;
+    }
 
+    // Calculate the product of all elements in 'newShape' except for the -1
+    // placeholder, which we discard by negating the result.
+    int64_t totalSizeNoPlaceholder = -std::accumulate(
+        newShape.begin(), newShape.end(), 1, std::multiplies());
+
+    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
+    if (totalSizeNoPlaceholder == 0)
+      return 0;
+
+    // Resolve the placeholder as the quotient between the total tensor size and
+    // the product of all other sizes.
+    return totalSize / totalSizeNoPlaceholder;
+  });
+
+  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
+  // shaped input from being reshaped into a statically shaped result. We may
+  // simply turn the first result dimension dynamic to address this.
+  if (!inputIsStatic && resultIsStatic)
+    resultShape[0] = ShapedType::kDynamic;
+
+  // The 'tensor.expand_shape' op also forbids a statically shaped input from
+  // being reshaped into a dynamically shaped result, but the placeholder
+  // inference algorithm above guarantees that this will never be the case.
+  assert(!inputIsStatic || resultIsStatic);
+
+  // Create result type
+  return inputType.clone(resultShape);
+}
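
A quick worked example of the placeholder arithmetic above, for a static input (illustrative IR, not from the patch): with a 24-element input and new_shape = {2, -1, 3}, the product over all entries is -6, negated once to discard the -1 placeholder, so the placeholder resolves to 24 / 6 = 4 and the expanded type is tensor<2x4x3xf32>.

    func.func @resolve_placeholder(%arg0: tensor<24xf32>) -> tensor<2x4x3xf32> {
      %expanded = tensor.expand_shape %arg0 [[0, 1, 2]]
          : tensor<24xf32> into tensor<2x4x3xf32>
      return %expanded : tensor<2x4x3xf32>
    }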
+
+// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
+
+  if (lhsShape.empty() || rhsShape.empty())
+    return lhsType.clone(ArrayRef<int64_t>{});
+
+  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+    return lhsType.clone({ShapedType::kDynamic});
+
+  SmallVector<int64_t> intermediateShape;
   unsigned currLhsDim = 0, currRhsDim = 0;
   while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
     int64_t rhsSize = rhsShape[currRhsDim];
@@ -62,174 +137,113 @@ static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
     currLhsDim++;
   }
 
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
+  // Static shapes are guaranteed to be compatible by the op verifier, so all
+  // leftover dimensions should be 1.
+  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
+    assert(lhsShape[currLhsDim] == 1);
   }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
+  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
+    assert(rhsShape[currRhsDim] == 1);
   }
-
-  return true;
+
+  return lhsType.clone(intermediateShape);
}
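
For intuition, a static example of the intermediate type this computes (assuming the greedy size-matching in the loop elided by the hunk above): for lhs 2x3x4 and rhs 6x4, the first two lhs dims accumulate to 6, so the common intermediate is 6x4 and the collapse side of the pair would look like:

    func.func @intermediate(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> {
      %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]]
          : tensor<2x3x4xf32> into tensor<6x4xf32>
      return %collapsed : tensor<6x4xf32>
    }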
 
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+SmallVector<ReassociationExprs>
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+  auto srcShape = cast<TensorType>(srcType).getShape();
+  auto dstShape = cast<TensorType>(dstType).getShape();
 
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
+  if (srcShape.empty() || dstShape.empty())
+    return {};
 
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
+  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+    assert(dstShape.size() == 1);
+    SmallVector<AffineExpr, 2> exprs;
+    for (auto i : llvm::seq<int64_t>(srcShape.size()))
+      exprs.push_back(builder.getAffineDimExpr(i));
+    return {exprs};
   }
 
-  reassociationMap.resize(dstShape.size());
+  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
   unsigned currSrcDim = 0, currDstDim = 0;
   while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
     int64_t dstSize = dstShape[currDstDim];
     int64_t srcSize = srcShape[currSrcDim];
     while (srcSize < dstSize && currSrcDim < srcShape.size()) {
       reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
+          builder.getAffineDimExpr(currSrcDim++));
       srcSize *= srcShape[currSrcDim];
     }
     if (srcSize == dstSize) {
       reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
+          builder.getAffineDimExpr(currSrcDim++));
       // If the next dim in collapsedShape is not 1, treat subsequent dims in
       // expandedShape which are 1 to be collapsed.
       if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
         while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
           reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
+              builder.getAffineDimExpr(currSrcDim++));
         }
       }
     }
     currDstDim++;
   }
 
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+  // If the source and target shapes are compatible, both iterators must have
+  // reached the end. This condition is guaranteed by the op verifier for
+  // static shapes.
+  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
+  return reassociationMap;
 }
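
In the dynamic branch above, the map groups every source dimension into a single rank-1 result, which corresponds to IR along these lines (sketch, not from the patch):

    func.func @dynamic_collapse(%arg0: tensor<?x?x?xf32>) -> tensor<?xf32> {
      %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]]
          : tensor<?x?x?xf32> into tensor<?xf32>
      return %collapsed : tensor<?xf32>
    }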
 
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
-                     ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && resultTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot collapse dynamic dims to more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                          resultTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot collapse into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
-                                                  reassociationMap);
+// Create a tensor.collapse_shape op that reshapes the input into the given
+// result type.
+Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
+                     Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, input.getType(), resultType);
+  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
+                                                       reassociationMap);
 }
 
-Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
-                   ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && operandTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot expand dynamic dims from more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                          operandTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to expand into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic) ||
-      intermediateShape != operandTy.getShape()) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot expand into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
-                                                reassociationMap);
+// Create a tensor.expand_shape op that reshapes the input into the given
+// result type.
+Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
+                   Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, resultType, input.getType());
+  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
+                                                     reassociationMap);
 }
 
-class ReshapeConverterCollapseExpand
-    : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
-    }
-    auto intermediateTy = RankedTensorType::get(
-        intermediateShape, reshape.getType().getElementType());
-
-    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
-                                    adaptor.getInput1());
-    if (!collapse)
-      return failure();
-
-    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
-    if (!expand)
-      return failure();
-
-    rewriter.replaceOp(reshape, expand);
+    auto loc = reshape.getLoc();
+    auto resultType = reshape.getResult().getType();
+    auto input = reshape.getInput1();
+    auto newShape = reshape.getNewShape();
+
+    // Infer all intermediate types
+    auto inputType = inferReshapeInputType(input, newShape);
+    auto expandedType = inferReshapeExpandedType(inputType, newShape);
+    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
+
+    // Cast input if needed
+    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+
+    // Emit collapse-expand pair
+    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
+    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
+
+    // Cast to final result type if needed
+    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    rewriter.replaceOp(reshape, result);
     return success();
   }
 };
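
Putting the pattern together, a plausible end-to-end lowering, hand-derived from the logic above (function name and shapes are illustrative, not from the patch's tests): for an input tensor<?x?xf32> reshaped with new_shape = {2, 3}, the expanded type becomes ?x3 (its first dimension turned dynamic per the restriction noted in inferReshapeExpandedType), the collapsed type degenerates to the rank-1 dynamic tensor, and a final tensor.cast recovers the static result type. The createOrFold calls fold away any cast or reshape that turns out to be a no-op.

    func.func @reshape_dyn(%arg0: tensor<?x?xf32>) -> tensor<2x3xf32> {
      %collapsed = tensor.collapse_shape %arg0 [[0, 1]]
          : tensor<?x?xf32> into tensor<?xf32>
      %expanded = tensor.expand_shape %collapsed [[0, 1]]
          : tensor<?xf32> into tensor<?x3xf32>
      %result = tensor.cast %expanded : tensor<?x3xf32> to tensor<2x3xf32>
      return %result : tensor<2x3xf32>
    }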
@@ -416,8 +430,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
-      patterns->getContext());
-
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+  patterns->add<
+    ConcatConverter,
+    PadConverter,
+    ReshapeConverter,
+    SliceConverter
+  >(patterns->getContext());
 }