Skip to content

Commit c62371b

Browse files
TessilTai78641
authored andcommitted
[mlir][tosa] Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers
Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]>
1 parent 756dab4 commit c62371b

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240240
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241241
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242242

243-
if (auto quantType =
244-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245244
inputEType = quantType.getStorageType();
246245

247-
if (auto quantType =
248-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249247
biasEType = quantType.getStorageType();
250248

251-
if (auto quantType =
252-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253250
resultEType = quantType.getStorageType();
254251

255252
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346343
auto inputEType =
347344
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
348345

349-
if (auto quantType =
350-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351347
inputEType = quantType.getStorageType();
352348

353349
auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369365
if (inputEType.isF32() && !accType.isF32())
370366
return op.emitOpError("accumulator type for f32 tensor is not f32");
371367

372-
return success();
368+
auto resultEType =
369+
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
370+
371+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372+
resultEType = quantType.getStorageType();
373+
374+
// check allowed input/result element types combinations
375+
if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
376+
(inputEType.isInteger(16) && resultEType.isInteger(48)) ||
377+
(isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
378+
(isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
379+
(inputEType.isF16() && resultEType.isF16()) ||
380+
(inputEType.isBF16() && resultEType.isBF16()) ||
381+
(inputEType.isF32() && resultEType.isF32()))
382+
return success();
383+
384+
return op.emitOpError("input/output element types are incompatible.");
373385
}
374386

375387
// verify that inType and outType have same element types

0 commit comments

Comments
 (0)