@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240
240
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241
241
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242
242
243
- if (auto quantType =
244
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245
244
inputEType = quantType.getStorageType ();
246
245
247
- if (auto quantType =
248
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249
247
biasEType = quantType.getStorageType ();
250
248
251
- if (auto quantType =
252
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253
250
resultEType = quantType.getStorageType ();
254
251
255
252
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346
343
auto inputEType =
347
344
llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
348
345
349
- if (auto quantType =
350
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351
347
inputEType = quantType.getStorageType ();
352
348
353
349
auto accType = op.getAccType ();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369
365
if (inputEType.isF32 () && !accType.isF32 ())
370
366
return op.emitOpError (" accumulator type for f32 tensor is not f32" );
371
367
372
- return success ();
368
+ auto resultEType =
369
+ llvm::cast<ShapedType>(op.getResult ().getType ()).getElementType ();
370
+
371
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372
+ resultEType = quantType.getStorageType ();
373
+
374
+ // check allowed input/result element types combinations
375
+ if ((inputEType.isInteger (8 ) && resultEType.isInteger (32 )) ||
376
+ (inputEType.isInteger (16 ) && resultEType.isInteger (48 )) ||
377
+ (isa<Float8E5M2Type>(inputEType) && resultEType.isF16 ()) ||
378
+ (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16 ()) ||
379
+ (inputEType.isF16 () && resultEType.isF16 ()) ||
380
+ (inputEType.isBF16 () && resultEType.isBF16 ()) ||
381
+ (inputEType.isF32 () && resultEType.isF32 ()))
382
+ return success ();
383
+
384
+ return op.emitOpError (" input/output element types are incompatible." );
373
385
}
374
386
375
387
// verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
519
531
OpBuilder &builder, OperationState &result, Type outputType, Value input,
520
532
Value weight, Value bias, DenseI64ArrayAttr outpad,
521
533
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
522
- result.addOperands ({input, weight, bias});
534
+ auto zps = createZPsAsConst (builder, input, weight);
535
+ result.addOperands ({input, weight, bias, zps.first , zps.second });
523
536
result.addAttribute (" out_pad" , outpad);
524
537
result.addAttribute (" stride" , stride);
525
538
result.addAttribute (" out_shape" , outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
2478
2491
return failure ();
2479
2492
}
2480
2493
2481
- // Create a rank-0 const tensor for zero point of the source tensor.
2494
+ // Create a rank-1 const tensor for zero point of the source tensor.
2482
2495
std::optional<Value> mlir::tosa::createZeroPointTensor (OpBuilder &builder,
2483
2496
Location loc,
2484
2497
Type srcElemType,
2485
2498
int64_t zp) {
2486
- if (auto quantType =
2487
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
2488
- srcElemType = quantType.getStorageType ();
2489
-
2490
- auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
2499
+ srcElemType = getElementTypeOrSelf (srcElemType);
2491
2500
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
2492
2501
srcElemType = quantType.getStorageType ();
2502
+ auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
2493
2503
if (llvm::isa<FloatType>(srcElemType)) {
2494
2504
auto zpAttr = DenseElementsAttr::get (
2495
2505
zpType, builder.getFloatAttr (srcElemType, static_cast <double >(zp)));
0 commit comments