Skip to content

[mlir][tosa] Change zero points of convolution ops to required inputs #127679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2025

Conversation

Tai78641
Copy link
Contributor

This patch changes the input_zp and weight_zp for convolution operators to be required inputs
in order to align with the TOSA Spec 1.0.

Convolution operators affected are:
CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D.

Change-Id: I7aa6e05580c83c617394c3f7fd18fba8c3f3b0ec

@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This patch changes the input_zp and weight_zp for convolution operators to be required inputs
in order to align with the TOSA Spec 1.0.

Convolution operators affected are:
CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D.

Change-Id: I7aa6e05580c83c617394c3f7fd18fba8c3f3b0ec


Patch is 151.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127679.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (-106)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+40-8)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+41-22)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+109-49)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+17-7)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+27-27)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+61-32)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+10-6)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+51-18)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+108-117)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+17-16)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+5-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+8-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-61)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+5-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164..c18b46c9474fc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -161,112 +161,6 @@ namespace tosa {
 std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
                                            Type srcElemType, int64_t zp = 0);
 
-// Get zero point value from the attribute argument.
-LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
-
-// Verify if zero point falls into valid range.
-template <typename T>
-LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
-  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
-                !std::is_same_v<T, DepthwiseConv2DOp> &&
-                !std::is_same_v<T, TransposeConv2DOp>) {
-    return failure();
-  }
-
-  if (!zpElemType.isIntOrFloat())
-    return failure();
-
-  if (!zpElemType.isInteger(8) && zp != 0)
-    return failure();
-
-  if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
-    return failure();
-
-  if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
-    return failure();
-
-  return success();
-}
-
-// Helper type trait to determine if an operation is a tosa convolution.
-template <typename Op>
-struct IsTosaConv : std::false_type {};
-
-template <>
-struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
-
-template <typename Op>
-constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
-
-// Helper struct to hold the zero points of a TOSA convolution operation as
-// named 64-bit integer fields.
-struct ConvZpPair {
-  ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
-      : inputZp(inputZp), weightZp(weightZp) {}
-  std::int64_t inputZp;
-  std::int64_t weightZp;
-};
-
-// Helper function which attempts to extract the zero points from a TOSA
-// convolution by matching them against defining ops which should be tosa.const
-// operations.
-//
-// There are three possible results:
-// 1. Failed to extract the zero-points i.e. they should exist and don't or they
-// do exist but are invalid.
-// 2. Succeeded in extracting zero-points.
-// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
-// convolution.
-using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
-template <typename TosaConvOp>
-std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
-extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
-  // Strictly speaking the base TOSA spec requires that for non int8 types
-  // zero points must be zero. However, in the dialect these operands are
-  // optional and only required for int8. They have no semantic meaning for
-  // non-quantized types and can therefore be safely ignored. This is case 3.
-  if (auto opElementTY =
-          cast<ShapedType>(op->getOperand(0).getType()).getElementType();
-      !opElementTY.isInteger(8))
-    return FailOrMaybeZP(std::nullopt);
-
-  // Now we know we should have a zero point check it is valid.
-  if (!op.getInputZp())
-    return rewriter.notifyMatchFailure(op, "missing input zero point");
-
-  // Helper to extract the zero point by matching its definition against a
-  // constant.
-  auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
-    ElementsAttr zpAttr;
-    if (!matchPattern(zpValue, m_Constant(&zpAttr)))
-      return std::nullopt;
-
-    int64_t zp;
-    if (tosa::getZeroPoint(zpAttr, zp).failed())
-      return std::nullopt;
-
-    return std::make_optional(zp);
-  };
-
-  auto maybeInputZp = extractZeroPoint(op.getInputZp());
-  if (!maybeInputZp)
-    return rewriter.notifyMatchFailure(op, "unable to extract input zp");
-
-  if (!op.getWeightZp())
-    return rewriter.notifyMatchFailure(op, "missing weight zero point");
-
-  auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
-  if (!maybeWeightZp)
-    return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
-
-  return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
-}
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8ac6a81dc38af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -105,8 +105,9 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -118,6 +119,13 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -136,8 +144,9 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$input,
     TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -149,6 +158,13 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -168,8 +184,9 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -181,6 +198,13 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -330,8 +354,9 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
@@ -343,6 +368,13 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_TransConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..bb7d5a23d9365 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (llvm::failed(failureOrMaybeZps))
-      return failure();
+    // Get and verify zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
+
+    if (op.verifyInputZeroPoint(inputZpVal).failed() ||
+        op.verifyWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
 
-    auto maybeZps = failureOrMaybeZps.value();
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -289,7 +299,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (maybeZps) {
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -297,11 +307,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.conv op quantization has zp outside of input range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       // For 2D convolutions, we need to check if the target convolution op
       // wants a HWCF kernel layout.
       bool wantHwcf =
-          maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
-                   : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
       if (wantHwcf) {
         // Transpose the kernel to match dimension ordering of the linalg
         // convolution operation.
@@ -376,9 +386,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (maybeZps) {
-      auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
-      auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
+    if (hasZp) {
+      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -441,18 +451,27 @@ class DepthwiseConvConverter
         /*inputSizeDims=*/{1, 2},
         /*kernelSizeDims=*/{0, 1}, rewriter);
 
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (llvm::failed(failureOrMaybeZps))
-      return failure();
+    // Get and verify zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
 
-    auto maybeZps = failureOrMaybeZps.value();
+    if (op.verifyInputZeroPoint(inputZpVal).failed() ||
+        op.verifyWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
 
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (maybeZps) {
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -460,12 +479,12 @@ class DepthwiseConvConverter
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.depthwise_conv op quantization has zp outside of input "
                 "range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -505,7 +524,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    if (!maybeZps) {
+    if (!hasZp) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                            loc, linalgConvTy, ValueRange{input, weight},
@@ -532,8 +551,8 @@ class DepthwiseConvConverter
               .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
-      IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
-      IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
+      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
       Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..ad8077d3d47a6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,18 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Tosa utilities.
+//===----------------------------------------------------------------------===//
+
+static Type getStorageElementTypeOrSelf(Type type) {
+  auto elementType = getElementTypeOrSelf(type);
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(elementType))
+    elementType = quantType.getStorageType();
+
+  return elementType;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -243,6 +255,9 @@ static LogicalResult verifyConvOp(T op) {
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
+    weightEType = quantType.getStorageType();
+
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
@@ -279,36 +294,32 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // We require an explicit input zero point and weight zero point for i8
-  // convolution.
-  if (!op.getInputZp() && !op.getWeightZp())
-    return inputEType.isInteger(8) ? failure() : success();
+  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
+  if (inputEType != inputZpEType) {
+    return op.emitOpError("expect both input and its zero point are the same "
+                          "element type, got ")
+           << inputEType << " and " << inputZpEType;
+  }
 
-  ElementsAttr inputZpAttr;
-  ElementsAttr weightZpAttr;
-  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
-      !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
-    op.emitOpError(
-        "bail out if the actual value of zero points cannot be determined");
-    return failure();
+  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
+  if (weightEType != weightZpEType) {
+    return op.emitOpError("expect both weight and its zero point are the same "
+                          "element type, got ")
+           << weightEType << " and " << weightZpEType;
   }
 
-  // Get and verify explicit zero points.
   int64_t inputZpVal;
-  int64_t weightZpVal;
-
-  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
-          .failed()) {
-    op.emitOpError("input zero point must be zero for non-int8 integer types");
-    return failure();
+  if (op.getInputZeroPoint(inputZpVal).succeeded()) {
+    if (op.verifyInputZeroPoint(inputZpVal).failed())
+      return op.emitOpError(
+          "input zero point must be zero for non-int8 integer types");
   }
 
-  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
-          .failed()) {
-    op.emitOpError("weight zero point must be zero for non-int8 integer types");
-    return failure();
+  int64_t weightZpVal;
+  if (op.getWeightZeroPoint(weightZpVal).succeeded()) {
+    if (op.verifyWeightZeroPoint(weightZpVal).failed())
+      return op.emitOpError(
+          "weight zero point must be zero for non-int8 integer types");
   }
 
   return success();
@@ -1371,6 +1382,79 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
+template <typename T>
+static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {
+  ElementsAttr zpAttr;
+  if (!matchPattern(val, m_Constant(&zpAttr))) {
+    return failure();
+  }
+
+  Type zpElemType = zpAttr.getElementType();
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
+    zp = quantType.getZeroPoint();
+    return success();
+  }
+
+  if (llvm::isa<FloatType>(zpElemType)) {
+    if (!zpAttr.getValues<APFloat>()[0].isZero())
+      return op.emitOpError(
+          "non-zero zero point is not allowed for float types");
+    zp = 0;
+    return success();
+  }
+
+  if (llvm::isa<IntegerType>(zpElemType)) {
+    zp = zpAttr.getValues<APInt>()[0].getSExtValue();
+    return success();
+  }
+
+  return op.emitOpError("zero point is not allowed for unsupported types");
+}
+
+template <typename T>
+static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
+  // TODO clean it up when the entire zero point (attribute -> input tensor
+  // type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
+  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
+                !std::is_same_v<T, DepthwiseConv2DOp> &&
+                !std::is_same_v<T, TransposeConv2DOp>)
+    return failure();
+
+  Type zpElemType = getElementTypeOrSelf(val);
+
+  if (!zpElemType.isIntOrFloat())
+    return op.emitOpError("zero point is not integer or float typss");
+
+  if (!zpElemType.isInteger(8) && zp != 0)
+    return op.emitOpError("zero point must be zero for non-int8 integer types");
+
+  if (zp < -128 || zp > 127)
+    return failure();
+
+  return success();
+}
+
+#define ZERO_POINT_HELPER(OP)                                                  \
+  LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) {                     \
+    return getZeroPoint(*this, getInputZp(), zp);                              \
+  }                                                            ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Tai Ly (Tai78641)

Changes

This patch changes the input_zp and weight_zp for convolution operators to be required inputs
in order to align with the TOSA Spec 1.0.

Convolution operators affected are:
CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D.

Change-Id: I7aa6e05580c83c617394c3f7fd18fba8c3f3b0ec


Patch is 151.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127679.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (-106)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+40-8)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+41-22)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+109-49)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+17-7)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+27-27)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+61-32)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+10-6)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+51-18)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+108-117)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+17-16)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+5-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+8-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-61)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+5-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164..c18b46c9474fc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -161,112 +161,6 @@ namespace tosa {
 std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
                                            Type srcElemType, int64_t zp = 0);
 
-// Get zero point value from the attribute argument.
-LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
-
-// Verify if zero point falls into valid range.
-template <typename T>
-LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
-  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
-                !std::is_same_v<T, DepthwiseConv2DOp> &&
-                !std::is_same_v<T, TransposeConv2DOp>) {
-    return failure();
-  }
-
-  if (!zpElemType.isIntOrFloat())
-    return failure();
-
-  if (!zpElemType.isInteger(8) && zp != 0)
-    return failure();
-
-  if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
-    return failure();
-
-  if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
-    return failure();
-
-  return success();
-}
-
-// Helper type trait to determine if an operation is a tosa convolution.
-template <typename Op>
-struct IsTosaConv : std::false_type {};
-
-template <>
-struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
-
-template <typename Op>
-constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
-
-// Helper struct to hold the zero points of a TOSA convolution operation as
-// named 64-bit integer fields.
-struct ConvZpPair {
-  ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
-      : inputZp(inputZp), weightZp(weightZp) {}
-  std::int64_t inputZp;
-  std::int64_t weightZp;
-};
-
-// Helper function which attempts to extract the zero points from a TOSA
-// convolution by matching them against defining ops which should be tosa.const
-// operations.
-//
-// There are three possible results:
-// 1. Failed to extract the zero-points i.e. they should exist and don't or they
-// do exist but are invalid.
-// 2. Succeeded in extracting zero-points.
-// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
-// convolution.
-using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
-template <typename TosaConvOp>
-std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
-extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
-  // Strictly speaking the base TOSA spec requires that for non int8 types
-  // zero points must be zero. However, in the dialect these operands are
-  // optional and only required for int8. They have no semantic meaning for
-  // non-quantized types and can therefore be safely ignored. This is case 3.
-  if (auto opElementTY =
-          cast<ShapedType>(op->getOperand(0).getType()).getElementType();
-      !opElementTY.isInteger(8))
-    return FailOrMaybeZP(std::nullopt);
-
-  // Now we know we should have a zero point check it is valid.
-  if (!op.getInputZp())
-    return rewriter.notifyMatchFailure(op, "missing input zero point");
-
-  // Helper to extract the zero point by matching its definition against a
-  // constant.
-  auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
-    ElementsAttr zpAttr;
-    if (!matchPattern(zpValue, m_Constant(&zpAttr)))
-      return std::nullopt;
-
-    int64_t zp;
-    if (tosa::getZeroPoint(zpAttr, zp).failed())
-      return std::nullopt;
-
-    return std::make_optional(zp);
-  };
-
-  auto maybeInputZp = extractZeroPoint(op.getInputZp());
-  if (!maybeInputZp)
-    return rewriter.notifyMatchFailure(op, "unable to extract input zp");
-
-  if (!op.getWeightZp())
-    return rewriter.notifyMatchFailure(op, "missing weight zero point");
-
-  auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
-  if (!maybeWeightZp)
-    return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
-
-  return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
-}
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8ac6a81dc38af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -105,8 +105,9 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -118,6 +119,13 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -136,8 +144,9 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$input,
     TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -149,6 +158,13 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -168,8 +184,9 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -181,6 +198,13 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -330,8 +354,9 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
@@ -343,6 +368,13 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_TransConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..bb7d5a23d9365 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (llvm::failed(failureOrMaybeZps))
-      return failure();
+    // Get and verify zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
+
+    if (op.verifyInputZeroPoint(inputZpVal).failed() ||
+        op.verifyWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
 
-    auto maybeZps = failureOrMaybeZps.value();
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -289,7 +299,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (maybeZps) {
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -297,11 +307,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.conv op quantization has zp outside of input range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       // For 2D convolutions, we need to check if the target convolution op
       // wants a HWCF kernel layout.
       bool wantHwcf =
-          maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
-                   : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
       if (wantHwcf) {
         // Transpose the kernel to match dimension ordering of the linalg
         // convolution operation.
@@ -376,9 +386,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (maybeZps) {
-      auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
-      auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
+    if (hasZp) {
+      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -441,18 +451,27 @@ class DepthwiseConvConverter
         /*inputSizeDims=*/{1, 2},
         /*kernelSizeDims=*/{0, 1}, rewriter);
 
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (llvm::failed(failureOrMaybeZps))
-      return failure();
+    // Get and verify zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
 
-    auto maybeZps = failureOrMaybeZps.value();
+    if (op.verifyInputZeroPoint(inputZpVal).failed() ||
+        op.verifyWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
 
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (maybeZps) {
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -460,12 +479,12 @@ class DepthwiseConvConverter
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.depthwise_conv op quantization has zp outside of input "
                 "range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -505,7 +524,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    if (!maybeZps) {
+    if (!hasZp) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                            loc, linalgConvTy, ValueRange{input, weight},
@@ -532,8 +551,8 @@ class DepthwiseConvConverter
               .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
-      IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
-      IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
+      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
       Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..ad8077d3d47a6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,18 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Tosa utilities.
+//===----------------------------------------------------------------------===//
+
+static Type getStorageElementTypeOrSelf(Type type) {
+  auto elementType = getElementTypeOrSelf(type);
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(elementType))
+    elementType = quantType.getStorageType();
+
+  return elementType;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -243,6 +255,9 @@ static LogicalResult verifyConvOp(T op) {
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
+    weightEType = quantType.getStorageType();
+
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
@@ -279,36 +294,32 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // We require an explicit input zero point and weight zero point for i8
-  // convolution.
-  if (!op.getInputZp() && !op.getWeightZp())
-    return inputEType.isInteger(8) ? failure() : success();
+  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
+  if (inputEType != inputZpEType) {
+    return op.emitOpError("expect both input and its zero point are the same "
+                          "element type, got ")
+           << inputEType << " and " << inputZpEType;
+  }
 
-  ElementsAttr inputZpAttr;
-  ElementsAttr weightZpAttr;
-  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
-      !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
-    op.emitOpError(
-        "bail out if the actual value of zero points cannot be determined");
-    return failure();
+  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
+  if (weightEType != weightZpEType) {
+    return op.emitOpError("expect both weight and its zero point are the same "
+                          "element type, got ")
+           << weightEType << " and " << weightZpEType;
   }
 
-  // Get and verify explicit zero points.
   int64_t inputZpVal;
-  int64_t weightZpVal;
-
-  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
-          .failed()) {
-    op.emitOpError("input zero point must be zero for non-int8 integer types");
-    return failure();
+  if (op.getInputZeroPoint(inputZpVal).succeeded()) {
+    if (op.verifyInputZeroPoint(inputZpVal).failed())
+      return op.emitOpError(
+          "input zero point must be zero for non-int8 integer types");
   }
 
-  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
-          .failed()) {
-    op.emitOpError("weight zero point must be zero for non-int8 integer types");
-    return failure();
+  int64_t weightZpVal;
+  if (op.getWeightZeroPoint(weightZpVal).succeeded()) {
+    if (op.verifyWeightZeroPoint(weightZpVal).failed())
+      return op.emitOpError(
+          "weight zero point must be zero for non-int8 integer types");
   }
 
   return success();
@@ -1371,6 +1382,79 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
+template <typename T>
+static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {
+  ElementsAttr zpAttr;
+  if (!matchPattern(val, m_Constant(&zpAttr))) {
+    return failure();
+  }
+
+  Type zpElemType = zpAttr.getElementType();
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
+    zp = quantType.getZeroPoint();
+    return success();
+  }
+
+  if (llvm::isa<FloatType>(zpElemType)) {
+    if (!zpAttr.getValues<APFloat>()[0].isZero())
+      return op.emitOpError(
+          "non-zero zero point is not allowed for float types");
+    zp = 0;
+    return success();
+  }
+
+  if (llvm::isa<IntegerType>(zpElemType)) {
+    zp = zpAttr.getValues<APInt>()[0].getSExtValue();
+    return success();
+  }
+
+  return op.emitOpError("zero point is not allowed for unsupported types");
+}
+
+template <typename T>
+static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
+  // TODO clean it up when the entire zero point (attribute -> input tensor
+  // type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
+  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
+                !std::is_same_v<T, DepthwiseConv2DOp> &&
+                !std::is_same_v<T, TransposeConv2DOp>)
+    return failure();
+
+  Type zpElemType = getElementTypeOrSelf(val);
+
+  if (!zpElemType.isIntOrFloat())
+    return op.emitOpError("zero point is not integer or float typss");
+
+  if (!zpElemType.isInteger(8) && zp != 0)
+    return op.emitOpError("zero point must be zero for non-int8 integer types");
+
+  if (zp < -128 || zp > 127)
+    return failure();
+
+  return success();
+}
+
+#define ZERO_POINT_HELPER(OP)                                                  \
+  LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) {                     \
+    return getZeroPoint(*this, getInputZp(), zp);                              \
+  }                                                            ...
[truncated]

@Tai78641 Tai78641 force-pushed the pr_require_conv_zps branch 2 times, most recently from 129a1c6 to 93388fe Compare February 19, 2025 21:50
@Tai78641
Copy link
Contributor Author

rebased

@Tai78641 Tai78641 force-pushed the pr_require_conv_zps branch 2 times, most recently from 214ea7b to 7494004 Compare February 25, 2025 18:35
Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Leaving to @lhutton1

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Tai78641 LGTM!

This patch changes the input_zp and weight_zp for convolution
operators to be required inputs.
Convolution operators affected are:
	CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I7aa6e05580c83c617394c3f7fd18fba8c3f3b0ec
@Tai78641 Tai78641 force-pushed the pr_require_conv_zps branch from 7494004 to 42f001d Compare February 26, 2025 17:50
@Jerry-Ge Jerry-Ge merged commit ca5bb23 into llvm:main Feb 26, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants