@@ -271,6 +271,55 @@ LogicalResult tosa::ConstOp::verify() {
271
271
return success ();
272
272
}
273
273
274
+ template <typename T>
275
+ static LogicalResult verifyConvOpModes (T op) {
276
+ auto inputEType =
277
+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
278
+
279
+ if (auto quantType =
280
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
281
+ inputEType = quantType.getStorageType ();
282
+
283
+ auto accType = op.getAccType ();
284
+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
285
+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
286
+
287
+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
288
+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
289
+
290
+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3FN ()) &&
291
+ !accType.isF16 ())
292
+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
293
+
294
+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
295
+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
296
+
297
+ if (inputEType.isBF16 () && !accType.isF32 ())
298
+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
299
+
300
+ if (inputEType.isF32 () && !accType.isF32 ())
301
+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
302
+
303
+ auto resultEType =
304
+ llvm::cast<ShapedType>(op.getResult ().getType ()).getElementType ();
305
+
306
+ if (auto quantType =
307
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
308
+ resultEType = quantType.getStorageType ();
309
+
310
+ // check allowed input/result element types combinations
311
+ if ((inputEType.isInteger (8 ) && resultEType.isInteger (32 )) ||
312
+ (inputEType.isInteger (16 ) && resultEType.isInteger (48 )) ||
313
+ (inputEType.isFloat8E5M2 () && resultEType.isF16 ()) ||
314
+ (inputEType.isFloat8E4M3FN () && resultEType.isF16 ()) ||
315
+ (inputEType.isF16 () && resultEType.isF16 ()) ||
316
+ (inputEType.isBF16 () && resultEType.isBF16 ()) ||
317
+ (inputEType.isF32 () && resultEType.isF32 ()))
318
+ return success ();
319
+
320
+ return op.emitOpError (" input/output element types are incompatible." );
321
+ }
322
+
274
323
LogicalResult tosa::ArgMaxOp::verify () {
275
324
// Ensure output is of 32-bit integer
276
325
const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +417,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368
417
Type outputType, Value input, Value weight,
369
418
Value bias, DenseI64ArrayAttr pad,
370
419
DenseI64ArrayAttr stride,
371
- DenseI64ArrayAttr dilation) {
420
+ DenseI64ArrayAttr dilation,
421
+ TypeAttr accType) {
372
422
373
423
result.addOperands ({input, weight, bias});
374
424
result.addAttribute (" pad" , pad);
375
425
result.addAttribute (" stride" , stride);
376
426
result.addAttribute (" dilation" , dilation);
427
+ result.addAttribute (" acc_type" , accType);
377
428
378
429
auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379
430
if (quantAttr) {
@@ -390,11 +441,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390
441
static void buildTransConvOpWithQuantInfo (
391
442
OpBuilder &builder, OperationState &result, Type outputType, Value input,
392
443
Value weight, Value bias, DenseI64ArrayAttr outpad,
393
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
444
+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394
445
result.addOperands ({input, weight, bias});
395
446
result.addAttribute (" out_pad" , outpad);
396
447
result.addAttribute (" stride" , stride);
397
448
result.addAttribute (" out_shape" , outputShape);
449
+ result.addAttribute (" acc_type" , accType);
398
450
auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399
451
400
452
if (quantAttr) {
@@ -1595,7 +1647,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
1595
1647
return success ();
1596
1648
}
1597
1649
1598
- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1650
+ LogicalResult Conv2DOp::verify () {
1651
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1652
+ return failure ();
1653
+ return success ();
1654
+ }
1599
1655
1600
1656
LogicalResult Conv3DOp::inferReturnTypeComponents (
1601
1657
MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1723,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
1667
1723
return success ();
1668
1724
}
1669
1725
1670
- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1726
+ LogicalResult Conv3DOp::verify () {
1727
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1728
+ return failure ();
1729
+ return success ();
1730
+ }
1671
1731
1672
1732
LogicalResult AvgPool2dOp::inferReturnTypeComponents (
1673
1733
MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1822,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1762
1822
return success ();
1763
1823
}
1764
1824
1765
- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1825
+ LogicalResult DepthwiseConv2DOp::verify () {
1826
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1827
+ return failure ();
1828
+ return success ();
1829
+ }
1766
1830
1767
1831
LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
1768
1832
MLIRContext *context, ::std::optional<Location> location,
0 commit comments