Skip to content

Commit b67944b

Browse files
TessilTai78641
authored andcommitted
[mlir][tosa] Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers
also fixed buildTransConvOpWithQuantInfo to insert input/weight zp operands Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]>
1 parent 4d7192a commit b67944b

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240240
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241241
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242242

243-
if (auto quantType =
244-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245244
inputEType = quantType.getStorageType();
246245

247-
if (auto quantType =
248-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249247
biasEType = quantType.getStorageType();
250248

251-
if (auto quantType =
252-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253250
resultEType = quantType.getStorageType();
254251

255252
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346343
auto inputEType =
347344
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
348345

349-
if (auto quantType =
350-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351347
inputEType = quantType.getStorageType();
352348

353349
auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369365
if (inputEType.isF32() && !accType.isF32())
370366
return op.emitOpError("accumulator type for f32 tensor is not f32");
371367

372-
return success();
368+
auto resultEType =
369+
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
370+
371+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372+
resultEType = quantType.getStorageType();
373+
374+
// check allowed input/result element types combinations
375+
if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
376+
(inputEType.isInteger(16) && resultEType.isInteger(48)) ||
377+
(isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
378+
(isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
379+
(inputEType.isF16() && resultEType.isF16()) ||
380+
(inputEType.isBF16() && resultEType.isBF16()) ||
381+
(inputEType.isF32() && resultEType.isF32()))
382+
return success();
383+
384+
return op.emitOpError("input/output element types are incompatible.");
373385
}
374386

375387
// verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
519531
OpBuilder &builder, OperationState &result, Type outputType, Value input,
520532
Value weight, Value bias, DenseI64ArrayAttr outpad,
521533
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
522-
result.addOperands({input, weight, bias});
534+
auto zps = createZPsAsConst(builder, input, weight);
535+
result.addOperands({input, weight, bias, zps.first, zps.second});
523536
result.addAttribute("out_pad", outpad);
524537
result.addAttribute("stride", stride);
525538
result.addAttribute("out_shape", outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
24782491
return failure();
24792492
}
24802493

2481-
// Create a rank-0 const tensor for zero point of the source tensor.
2494+
// Create a rank-1 const tensor for zero point of the source tensor.
24822495
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
24832496
Location loc,
24842497
Type srcElemType,
24852498
int64_t zp) {
2486-
if (auto quantType =
2487-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
2488-
srcElemType = quantType.getStorageType();
2489-
2490-
auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
2499+
srcElemType = getElementTypeOrSelf(srcElemType);
24912500
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
24922501
srcElemType = quantType.getStorageType();
2502+
auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
24932503
if (llvm::isa<FloatType>(srcElemType)) {
24942504
auto zpAttr = DenseElementsAttr::get(
24952505
zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));

0 commit comments

Comments
 (0)