@@ -210,15 +210,26 @@ template <typename T>
210
210
static LogicalResult verifyConvOp (T op) {
211
211
// All TOSA conv ops have an input() and weight().
212
212
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
213
- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
213
+
214
+ RankedTensorType weightType;
215
+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
216
+ weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter ().getType ());
217
+ else
218
+ weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
214
219
215
220
// Must be ranked tensor types
216
221
if (!inputType) {
217
222
op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
218
223
return failure ();
219
224
}
220
225
if (!weightType) {
221
- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
226
+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
227
+ op.emitOpError (" expect a ranked tensor for filter, got " )
228
+ << op.getFilter ();
229
+ } else {
230
+ op.emitOpError (" expect a ranked tensor for weight, got " )
231
+ << op.getWeight ();
232
+ }
222
233
return failure ();
223
234
}
224
235
@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
271
282
return success ();
272
283
}
273
284
285
+ template <typename T>
286
+ static LogicalResult verifyConvOpModes (T op) {
287
+ auto inputEType =
288
+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
289
+
290
+ if (auto quantType =
291
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
292
+ inputEType = quantType.getStorageType ();
293
+
294
+ auto accType = op.getAccType ();
295
+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
296
+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
297
+
298
+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
299
+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
300
+
301
+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3 ()) &&
302
+ !accType.isF16 ())
303
+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
304
+
305
+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
306
+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
307
+
308
+ if (inputEType.isBF16 () && !accType.isF32 ())
309
+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
310
+
311
+ if (inputEType.isF32 () && !accType.isF32 ())
312
+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
313
+
314
+ return success ();
315
+ }
316
+
274
317
LogicalResult tosa::ArgMaxOp::verify () {
275
318
// Ensure output is of 32-bit integer
276
319
const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368
411
Type outputType, Value input, Value weight,
369
412
Value bias, DenseI64ArrayAttr pad,
370
413
DenseI64ArrayAttr stride,
371
- DenseI64ArrayAttr dilation) {
414
+ DenseI64ArrayAttr dilation,
415
+ TypeAttr accType) {
372
416
373
417
result.addOperands ({input, weight, bias});
374
418
result.addAttribute (" pad" , pad);
375
419
result.addAttribute (" stride" , stride);
376
420
result.addAttribute (" dilation" , dilation);
421
+ result.addAttribute (" acc_type" , accType);
377
422
378
423
auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379
424
if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390
435
static void buildTransConvOpWithQuantInfo (
391
436
OpBuilder &builder, OperationState &result, Type outputType, Value input,
392
437
Value weight, Value bias, DenseI64ArrayAttr outpad,
393
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
438
+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394
439
result.addOperands ({input, weight, bias});
395
440
result.addAttribute (" out_pad" , outpad);
396
441
result.addAttribute (" stride" , stride);
397
442
result.addAttribute (" out_shape" , outputShape);
443
+ result.addAttribute (" acc_type" , accType);
398
444
auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399
445
400
446
if (quantAttr) {
@@ -1599,7 +1645,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
1599
1645
return success ();
1600
1646
}
1601
1647
1602
- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1648
+ LogicalResult Conv2DOp::verify () {
1649
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1650
+ return failure ();
1651
+ return success ();
1652
+ }
1603
1653
1604
1654
LogicalResult Conv3DOp::inferReturnTypeComponents (
1605
1655
MLIRContext *context, ::std::optional<Location> location,
@@ -1671,7 +1721,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
1671
1721
return success ();
1672
1722
}
1673
1723
1674
- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1724
+ LogicalResult Conv3DOp::verify () {
1725
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1726
+ return failure ();
1727
+ return success ();
1728
+ }
1675
1729
1676
1730
LogicalResult AvgPool2dOp::inferReturnTypeComponents (
1677
1731
MLIRContext *context, ::std::optional<Location> location,
@@ -1766,7 +1820,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1766
1820
return success ();
1767
1821
}
1768
1822
1769
- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1823
+ LogicalResult DepthwiseConv2DOp::verify () {
1824
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1825
+ return failure ();
1826
+ return success ();
1827
+ }
1770
1828
1771
1829
LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
1772
1830
MLIRContext *context, ::std::optional<Location> location,
@@ -1832,6 +1890,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1832
1890
return success ();
1833
1891
}
1834
1892
1893
+ LogicalResult TransposeConv2DOp::verify () {
1894
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1895
+ return failure ();
1896
+ return success ();
1897
+ }
1898
+
1835
1899
LogicalResult IfOp::inferReturnTypeComponents (
1836
1900
MLIRContext *context, ::std::optional<Location> location,
1837
1901
IfOp::Adaptor adaptor,
0 commit comments