
[mlir][tosa] Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers #126275


Merged
merged 1 commit into llvm:main on Feb 11, 2025

Conversation

Tai78641
Contributor

@Tai78641 Tai78641 commented Feb 7, 2025

Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e

@llvmbot
Member

llvmbot commented Feb 7, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers.

also fixed buildTransConvOpWithQuantInfo to insert input/weight zp operands

Change-Id: Ie1961af931864f801914a62976bc988881ee075e
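
For readers unfamiliar with the MLIR quantization type hierarchy: quant::QuantizedType is the common base class of UniformQuantizedType, UniformQuantizedPerAxisType, and the other quantized element types, and it already exposes getStorageType(). The verifiers in the diff below therefore switch their dyn_cast to the base class so non-uniform quantized element types are unwrapped the same way. The following is a minimal sketch of that unwrapping step only, under a hypothetical helper name (the patch itself inlines this logic at each use site, and the include path is assumed from recent MLIR source layouts):

// Minimal sketch, not part of this patch: unwrap a possibly-quantized
// element type to its storage type. quant::QuantizedType is the common base
// of UniformQuantizedType, UniformQuantizedPerAxisType, etc., so a single
// dyn_cast covers all of them.
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Types.h"

static mlir::Type getStorageElementType(mlir::Type elemType) {
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(elemType))
    return quantType.getStorageType();
  return elemType;
}

In the hunks below, this is the pattern that replaces each llvm::dyn_cast<mlir::quant::UniformQuantizedType> in verifyConvOp, verifyConvOpModes, and createZeroPointTensor.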


Full diff: https://github.com/llvm/llvm-project/pull/126275.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+26-16)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..6143a9d23a00394 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
   bool biasIsFloat = llvm::isa<FloatType>(biasEType);
   bool resultIsFloat = llvm::isa<FloatType>(resultEType);
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
     resultEType = quantType.getStorageType();
 
   if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
   auto inputEType =
       llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
   auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
   if (inputEType.isF32() && !accType.isF32())
     return op.emitOpError("accumulator type for f32 tensor is not f32");
 
-  return success();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  // check allowed input/result element types combinations
+  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+      (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
+      (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
+      (inputEType.isF16() && resultEType.isF16()) ||
+      (inputEType.isBF16() && resultEType.isBF16()) ||
+      (inputEType.isF32() && resultEType.isF32()))
+    return success();
+
+  return op.emitOpError("input/output element types are incompatible.");
 }
 
 // verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
     OpBuilder &builder, OperationState &result, Type outputType, Value input,
     Value weight, Value bias, DenseI64ArrayAttr outpad,
     DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
-  result.addOperands({input, weight, bias});
+  auto zps = createZPsAsConst(builder, input, weight);
+  result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
   result.addAttribute("out_shape", outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
   return failure();
 }
 
-// Create a rank-0 const tensor for zero point of the source tensor.
+// Create a rank-1 const tensor for zero point of the source tensor.
 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                        Location loc,
                                                        Type srcElemType,
                                                        int64_t zp) {
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
-    srcElemType = quantType.getStorageType();
-
-  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+  srcElemType = getElementTypeOrSelf(srcElemType);
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
     srcElemType = quantType.getStorageType();
+  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
   if (llvm::isa<FloatType>(srcElemType)) {
     auto zpAttr = DenseElementsAttr::get(
         zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
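
The buildTransConvOpWithQuantInfo change above calls a createZPsAsConst helper to materialize the input and weight zero points as extra operands; that helper's body is not part of this diff. Below is a hedged sketch of how such a helper could be assembled from the createZeroPointTensor function shown in the last hunk. It uses a different name to avoid implying it matches the in-tree implementation, assumes the declarations are visible via the TOSA headers, and assumes a zero point of 0 purely for illustration:

// Hedged sketch only; the in-tree createZPsAsConst may differ. It builds
// rank-1 zero-point constants for the input and weight operands via
// createZeroPointTensor (see the last hunk above), assuming a zero point
// of 0 for illustration.
#include <utility>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

static std::pair<mlir::Value, mlir::Value>
createZeroPointConstants(mlir::OpBuilder &builder, mlir::Value input,
                         mlir::Value weight) {
  mlir::Type inputEType = mlir::getElementTypeOrSelf(input.getType());
  mlir::Type weightEType = mlir::getElementTypeOrSelf(weight.getType());
  // createZeroPointTensor returns std::optional<Value>; a real helper would
  // handle the nullopt case rather than asserting via value().
  mlir::Value inputZp =
      mlir::tosa::createZeroPointTensor(builder, input.getLoc(), inputEType,
                                        /*zp=*/0)
          .value();
  mlir::Value weightZp =
      mlir::tosa::createZeroPointTensor(builder, weight.getLoc(), weightEType,
                                        /*zp=*/0)
          .value();
  return {inputZp, weightZp};
}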

@llvmbot
Member

llvmbot commented Feb 7, 2025

@llvm/pr-subscribers-mlir

Contributor

@GeorgeARM GeorgeARM left a comment


Do we need to add some tests? Are there any already in place?

Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e
Signed-off-by: Tai Ly <[email protected]>
@Tai78641
Contributor Author

Do we need to add some tests? Are there any already in place?

added tests

@Tai78641 Tai78641 requested a review from GeorgeARM February 10, 2025 21:42
@GeorgeARM GeorgeARM merged commit 1a8d2a4 into llvm:main Feb 11, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>