@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240
240
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241
241
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242
242
243
- if (auto quantType =
244
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245
244
inputEType = quantType.getStorageType ();
246
245
247
- if (auto quantType =
248
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249
247
biasEType = quantType.getStorageType ();
250
248
251
- if (auto quantType =
252
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253
250
resultEType = quantType.getStorageType ();
254
251
255
252
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346
343
auto inputEType =
347
344
llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
348
345
349
- if (auto quantType =
350
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351
347
inputEType = quantType.getStorageType ();
352
348
353
349
auto accType = op.getAccType ();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369
365
if (inputEType.isF32 () && !accType.isF32 ())
370
366
return op.emitOpError (" accumulator type for f32 tensor is not f32" );
371
367
372
- return success ();
368
+ auto resultEType =
369
+ llvm::cast<ShapedType>(op.getResult ().getType ()).getElementType ();
370
+
371
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372
+ resultEType = quantType.getStorageType ();
373
+
374
+ // check allowed input/result element types combinations
375
+ if ((inputEType.isInteger (8 ) && resultEType.isInteger (32 )) ||
376
+ (inputEType.isInteger (16 ) && resultEType.isInteger (48 )) ||
377
+ (isa<Float8E5M2Type>(inputEType) && resultEType.isF16 ()) ||
378
+ (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16 ()) ||
379
+ (inputEType.isF16 () && resultEType.isF16 ()) ||
380
+ (inputEType.isBF16 () && resultEType.isBF16 ()) ||
381
+ (inputEType.isF32 () && resultEType.isF32 ()))
382
+ return success ();
383
+
384
+ return op.emitOpError (" input/output element types are incompatible." );
373
385
}
374
386
375
387
// verify that inType and outType have same element types
0 commit comments