@@ -210,15 +210,26 @@ template <typename T>
210
210
static LogicalResult verifyConvOp (T op) {
211
211
// All TOSA conv ops have an input() and weight().
212
212
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
213
- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
213
+
214
+ RankedTensorType weightType;
215
+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
216
+ weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter ().getType ());
217
+ else
218
+ weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
214
219
215
220
// Must be ranked tensor types
216
221
if (!inputType) {
217
222
op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
218
223
return failure ();
219
224
}
220
225
if (!weightType) {
221
- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
226
+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
227
+ op.emitOpError (" expect a ranked tensor for filter, got " )
228
+ << op.getFilter ();
229
+ } else {
230
+ op.emitOpError (" expect a ranked tensor for weight, got " )
231
+ << op.getWeight ();
232
+ }
222
233
return failure ();
223
234
}
224
235
@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
271
282
return success ();
272
283
}
273
284
285
+ template <typename T>
286
+ static LogicalResult verifyConvOpModes (T op) {
287
+ auto inputEType =
288
+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
289
+
290
+ if (auto quantType =
291
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
292
+ inputEType = quantType.getStorageType ();
293
+
294
+ auto accType = op.getAccType ();
295
+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
296
+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
297
+
298
+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
299
+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
300
+
301
+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3 ()) &&
302
+ !accType.isF16 ())
303
+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
304
+
305
+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
306
+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
307
+
308
+ if (inputEType.isBF16 () && !accType.isF32 ())
309
+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
310
+
311
+ if (inputEType.isF32 () && !accType.isF32 ())
312
+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
313
+
314
+ return success ();
315
+ }
316
+
274
317
LogicalResult tosa::ArgMaxOp::verify () {
275
318
// Ensure output is of 32-bit integer
276
319
const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368
411
Type outputType, Value input, Value weight,
369
412
Value bias, DenseI64ArrayAttr pad,
370
413
DenseI64ArrayAttr stride,
371
- DenseI64ArrayAttr dilation) {
414
+ DenseI64ArrayAttr dilation,
415
+ TypeAttr accType) {
372
416
373
417
result.addOperands ({input, weight, bias});
374
418
result.addAttribute (" pad" , pad);
375
419
result.addAttribute (" stride" , stride);
376
420
result.addAttribute (" dilation" , dilation);
421
+ result.addAttribute (" acc_type" , accType);
377
422
378
423
auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379
424
if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390
435
static void buildTransConvOpWithQuantInfo (
391
436
OpBuilder &builder, OperationState &result, Type outputType, Value input,
392
437
Value weight, Value bias, DenseI64ArrayAttr outpad,
393
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
438
+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394
439
result.addOperands ({input, weight, bias});
395
440
result.addAttribute (" out_pad" , outpad);
396
441
result.addAttribute (" stride" , stride);
397
442
result.addAttribute (" out_shape" , outputShape);
443
+ result.addAttribute (" acc_type" , accType);
398
444
auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399
445
400
446
if (quantAttr) {
@@ -1595,7 +1641,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
1595
1641
return success ();
1596
1642
}
1597
1643
1598
- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1644
+ LogicalResult Conv2DOp::verify () {
1645
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1646
+ return failure ();
1647
+ return success ();
1648
+ }
1599
1649
1600
1650
LogicalResult Conv3DOp::inferReturnTypeComponents (
1601
1651
MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1717,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
1667
1717
return success ();
1668
1718
}
1669
1719
1670
- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1720
+ LogicalResult Conv3DOp::verify () {
1721
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1722
+ return failure ();
1723
+ return success ();
1724
+ }
1671
1725
1672
1726
LogicalResult AvgPool2dOp::inferReturnTypeComponents (
1673
1727
MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1816,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1762
1816
return success ();
1763
1817
}
1764
1818
1765
- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1819
+ LogicalResult DepthwiseConv2DOp::verify () {
1820
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1821
+ return failure ();
1822
+ return success ();
1823
+ }
1766
1824
1767
1825
LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
1768
1826
MLIRContext *context, ::std::optional<Location> location,
@@ -1828,6 +1886,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1828
1886
return success ();
1829
1887
}
1830
1888
1889
+ LogicalResult TransposeConv2DOp::verify () {
1890
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1891
+ return failure ();
1892
+ return success ();
1893
+ }
1894
+
1831
1895
LogicalResult IfOp::inferReturnTypeComponents (
1832
1896
MLIRContext *context, ::std::optional<Location> location,
1833
1897
IfOp::Adaptor adaptor,
0 commit comments