@@ -26,53 +26,6 @@ using namespace mlir::tosa;
26
26
27
27
namespace {
28
28
29
- template <typename TosaOp, typename ... Args>
30
- TosaOp createOpAndInfer (PatternRewriter &rewriter, Location loc, Type resultTy,
31
- Args &&...args) {
32
- auto op = rewriter.create <TosaOp>(loc, resultTy, args...);
33
-
34
- InferShapedTypeOpInterface shapeInterface =
35
- dyn_cast<InferShapedTypeOpInterface>(op.getOperation ());
36
- if (!shapeInterface)
37
- return op;
38
-
39
- SmallVector<ShapedTypeComponents> returnedShapes;
40
- if (shapeInterface
41
- .inferReturnTypeComponents (
42
- op.getContext (), op.getLoc (), op->getOperands (),
43
- op->getDiscardableAttrDictionary (), op->getPropertiesStorage (),
44
- op->getRegions (), returnedShapes)
45
- .failed ())
46
- return op;
47
-
48
- // We need to use the element type of the existing result type to generate
49
- // the new result shaped type. This is because rescale can include a cast to
50
- // different bit-width types and does not have a TypeAttr to define the
51
- // target type.
52
- auto result = op->getResult (0 );
53
- auto predictedShape = returnedShapes[0 ];
54
- auto currentKnowledge =
55
- mlir::tosa::ValueKnowledge::getKnowledgeFromType (resultTy);
56
-
57
- // Compute the knowledge based on the inferred type.
58
- auto inferredKnowledge =
59
- mlir::tosa::ValueKnowledge::getPessimisticValueState ();
60
- inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType ();
61
- inferredKnowledge.hasRank = predictedShape.hasRank ();
62
- if (predictedShape.hasRank ()) {
63
- for (auto dim : predictedShape.getDims ()) {
64
- inferredKnowledge.sizes .push_back (dim);
65
- }
66
- }
67
-
68
- // Compute the new type based on the joined version.
69
- auto newKnowledge =
70
- mlir::tosa::ValueKnowledge::join (currentKnowledge, inferredKnowledge);
71
- auto newTy = newKnowledge.getType ();
72
- result.setType (newTy);
73
- return op;
74
- }
75
-
76
29
class TransposeConvNonStridedConverter
77
30
: public OpRewritePattern<tosa::TransposeConv2DOp> {
78
31
public:
@@ -187,20 +140,20 @@ class TransposeConvStridedConverter
187
140
(weightWidth % stride[1 ]) ? (stride[1 ] - weightWidth % stride[1 ]) : 0 ;
188
141
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get (
189
142
RankedTensorType::get ({4 , 2 }, rewriter.getI32Type ()), weightPadding);
190
- Value weightPaddingVal = createOpAndInfer <tosa::ConstOp>(
143
+ Value weightPaddingVal = CreateOpAndInferShape <tosa::ConstOp>(
191
144
rewriter, loc, weightPaddingAttr.getType (), weightPaddingAttr);
192
145
193
146
if (op.getQuantizationInfo ().has_value ()) {
194
147
auto quantInfo = op.getQuantizationInfo ().value ();
195
- weight = createOpAndInfer <tosa::PadOp>(
148
+ weight = CreateOpAndInferShape <tosa::PadOp>(
196
149
rewriter, loc, UnrankedTensorType::get (weightETy), weight,
197
150
weightPaddingVal, nullptr ,
198
151
rewriter.getAttr <PadOpQuantizationAttr>(quantInfo.getWeightZp ()));
199
152
200
153
} else {
201
- weight = createOpAndInfer <tosa::PadOp>(rewriter, loc,
202
- UnrankedTensorType::get (weightETy),
203
- weight, weightPaddingVal);
154
+ weight = CreateOpAndInferShape <tosa::PadOp>(
155
+ rewriter, loc, UnrankedTensorType::get (weightETy), weight ,
156
+ weightPaddingVal);
204
157
}
205
158
206
159
weightTy = cast<ShapedType>(weight.getType ());
@@ -212,7 +165,7 @@ class TransposeConvStridedConverter
212
165
outputChannels, weightHeight / stride[0 ],
213
166
stride[0 ], weightWidth / stride[1 ],
214
167
stride[1 ], inputChannels};
215
- weight = createOpAndInfer <tosa::ReshapeOp>(
168
+ weight = CreateOpAndInferShape <tosa::ReshapeOp>(
216
169
rewriter, loc, UnrankedTensorType::get (weightETy), weight,
217
170
rewriter.getDenseI64ArrayAttr (weightReshapeDims0));
218
171
@@ -221,23 +174,23 @@ class TransposeConvStridedConverter
221
174
loc, RankedTensorType::get ({6 }, rewriter.getI32Type ()),
222
175
rewriter.getI32TensorAttr ({2 , 4 , 0 , 1 , 3 , 5 }));
223
176
224
- weight = createOpAndInfer <tosa::TransposeOp>(
177
+ weight = CreateOpAndInferShape <tosa::TransposeOp>(
225
178
rewriter, loc, UnrankedTensorType::get (weightETy), weight,
226
179
transposeWeightVal);
227
180
228
181
// Collapse the strides and output channels into a single dimension.
229
182
llvm::SmallVector<int64_t , 6 > weightReshapeDims1 = {
230
183
outputChannels * stride[0 ] * stride[1 ], weightHeight / stride[0 ],
231
184
weightWidth / stride[1 ], inputChannels};
232
- weight = createOpAndInfer <tosa::ReshapeOp>(
185
+ weight = CreateOpAndInferShape <tosa::ReshapeOp>(
233
186
rewriter, loc, UnrankedTensorType::get (weightETy), weight,
234
187
rewriter.getDenseI64ArrayAttr (weightReshapeDims1));
235
188
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType ());
236
189
237
- weight = createOpAndInfer <tosa::ReverseOp>(
190
+ weight = CreateOpAndInferShape <tosa::ReverseOp>(
238
191
rewriter, loc, UnrankedTensorType::get (weightETy), weight,
239
192
/* axis = */ rewriter.getI32IntegerAttr (1 ));
240
- weight = createOpAndInfer <tosa::ReverseOp>(
193
+ weight = CreateOpAndInferShape <tosa::ReverseOp>(
241
194
rewriter, loc, UnrankedTensorType::get (weightETy), weight,
242
195
/* axis = */ rewriter.getI32IntegerAttr (2 ));
243
196
@@ -251,19 +204,19 @@ class TransposeConvStridedConverter
251
204
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get (
252
205
RankedTensorType::get ({4 , 2 }, rewriter.getI32Type ()), inputPadding);
253
206
254
- Value inputPaddingVal = createOpAndInfer <tosa::ConstOp>(
207
+ Value inputPaddingVal = CreateOpAndInferShape <tosa::ConstOp>(
255
208
rewriter, loc, inputPaddingAttr.getType (), inputPaddingAttr);
256
209
257
210
if (op.getQuantizationInfo ().has_value ()) {
258
211
auto quantInfo = op.getQuantizationInfo ().value ();
259
- input = createOpAndInfer <tosa::PadOp>(
212
+ input = CreateOpAndInferShape <tosa::PadOp>(
260
213
rewriter, loc, UnrankedTensorType::get (inputETy), input,
261
214
inputPaddingVal, nullptr ,
262
215
rewriter.getAttr <PadOpQuantizationAttr>(quantInfo.getInputZp ()));
263
216
} else {
264
- input = createOpAndInfer <tosa::PadOp>(rewriter, loc,
265
- UnrankedTensorType::get (inputETy),
266
- input, inputPaddingVal);
217
+ input = CreateOpAndInferShape <tosa::PadOp>(
218
+ rewriter, loc, UnrankedTensorType::get (inputETy), input ,
219
+ inputPaddingVal);
267
220
}
268
221
269
222
// We use a zero bias as we need to broadcast the bias.
@@ -279,7 +232,7 @@ class TransposeConvStridedConverter
279
232
// Perform the convolution using the zero bias.
280
233
Value conv2d;
281
234
if (op.getQuantizationInfo ()) {
282
- conv2d = createOpAndInfer <tosa::Conv2DOp>(
235
+ conv2d = CreateOpAndInferShape <tosa::Conv2DOp>(
283
236
rewriter, loc, UnrankedTensorType::get (resultETy), input,
284
237
weight, zeroBias,
285
238
/* pad=*/ rewriter.getDenseI64ArrayAttr ({0 , 0 , 0 , 0 }),
@@ -288,7 +241,7 @@ class TransposeConvStridedConverter
288
241
*op.getQuantizationInfo ())
289
242
.getResult ();
290
243
} else {
291
- conv2d = createOpAndInfer <tosa::Conv2DOp>(
244
+ conv2d = CreateOpAndInferShape <tosa::Conv2DOp>(
292
245
rewriter, loc, UnrankedTensorType::get (resultETy), input,
293
246
weight, zeroBias,
294
247
/* pad=*/ rewriter.getDenseI64ArrayAttr ({0 , 0 , 0 , 0 }),
@@ -307,7 +260,7 @@ class TransposeConvStridedConverter
307
260
// Factor striding out of the convolution result.
308
261
llvm::SmallVector<int64_t , 6 > convReshapeDims0 = {
309
262
batch, convHeight, convWidth, stride[0 ], stride[1 ], outputChannels};
310
- conv2d = createOpAndInfer <tosa::ReshapeOp>(
263
+ conv2d = CreateOpAndInferShape <tosa::ReshapeOp>(
311
264
rewriter, loc, UnrankedTensorType::get (resultETy), conv2d,
312
265
rewriter.getDenseI64ArrayAttr (convReshapeDims0));
313
266
@@ -316,14 +269,14 @@ class TransposeConvStridedConverter
316
269
loc, RankedTensorType::get ({6 }, rewriter.getI32Type ()),
317
270
rewriter.getI32TensorAttr ({0 , 1 , 3 , 2 , 4 , 5 }));
318
271
319
- conv2d = createOpAndInfer <tosa::TransposeOp>(
272
+ conv2d = CreateOpAndInferShape <tosa::TransposeOp>(
320
273
rewriter, loc, UnrankedTensorType::get (convETy), conv2d,
321
274
transposeConvVal);
322
275
323
276
// Fuse striding behavior back into width / height.
324
277
llvm::SmallVector<int64_t , 6 > convReshapeDims1 = {
325
278
batch, convHeight * stride[0 ], convWidth * stride[1 ], outputChannels};
326
- conv2d = createOpAndInfer <tosa::ReshapeOp>(
279
+ conv2d = CreateOpAndInferShape <tosa::ReshapeOp>(
327
280
rewriter, loc, UnrankedTensorType::get (resultETy), conv2d,
328
281
rewriter.getDenseI64ArrayAttr (convReshapeDims1));
329
282
@@ -348,7 +301,7 @@ class TransposeConvStridedConverter
348
301
sliceSize[1 ] = resultSliceHeight;
349
302
sliceSize[2 ] = resultSliceWidth;
350
303
351
- auto slice = createOpAndInfer <tosa::SliceOp>(
304
+ auto slice = CreateOpAndInferShape <tosa::SliceOp>(
352
305
rewriter, loc, UnrankedTensorType::get (resultETy), conv2d,
353
306
rewriter.getDenseI64ArrayAttr (sliceBegin),
354
307
rewriter.getDenseI64ArrayAttr (sliceSize))
@@ -363,10 +316,10 @@ class TransposeConvStridedConverter
363
316
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get (
364
317
RankedTensorType::get ({4 , 2 }, rewriter.getI32Type ()), resultPadding);
365
318
366
- Value resultPaddingVal = createOpAndInfer <tosa::ConstOp>(
319
+ Value resultPaddingVal = CreateOpAndInferShape <tosa::ConstOp>(
367
320
rewriter, loc, resultPaddingAttr.getType (), resultPaddingAttr);
368
321
369
- Value resultPad = createOpAndInfer <tosa::PadOp>(
322
+ Value resultPad = CreateOpAndInferShape <tosa::PadOp>(
370
323
rewriter, loc, UnrankedTensorType::get (resultETy), slice,
371
324
resultPaddingVal);
372
325
0 commit comments