Skip to content

[mlir][tosa] Change Rescale zero points to be inputs #130340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 11, 2025

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Mar 7, 2025

*Update RescaleOp to use zero-point as operands instead of attributes.
*Check input_zp data type against the input and output_zp data type
against the output.

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

*Update RescaleOp to use zero-point as operands instead of attributes.
*Check input_zp data type against the input and output_zp data type
against the output.


Patch is 61.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130340.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13-10)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+9-2)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-8)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+69-34)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+4-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+45-24)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+7-5)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+54-18)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+5-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-2)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+15)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+6-2)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+10-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..08f28a7538c3d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -216,15 +216,15 @@ profileComplianceMap = {
         {fp32T, fp16T}}}}},
     {"tosa.rescale",
      {{{Profile::pro_int},
-       {{i8T, i8T},
-        {i8T, i16T},
-        {i8T, i32T},
-        {i16T, i8T},
-        {i16T, i16T},
-        {i16T, i32T},
-        {i32T, i8T},
-        {i32T, i16T},
-        {i32T, i32T}}}}},
+       {{i8T, i8T, i8T, i8T},
+        {i8T, i8T, i16T, i16T},
+        {i8T, i8T, i32T, i32T},
+        {i16T, i16T, i8T, i8T},
+        {i16T, i16T, i16T, i16T},
+        {i16T, i16T, i32T, i32T},
+        {i32T, i32T, i8T, i8T},
+        {i32T, i32T, i16T, i16T},
+        {i32T, i32T, i32T, i32T}}}}},
     {"tosa.const",
      {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
       {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -384,7 +384,10 @@ extensionComplianceMap = {
         {fp16T, fp8e5m2T},
         {fp32T, fp8e5m2T}}}}},
     {"tosa.rescale",
-     {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+     {{{Extension::int16},
+       {{i48T, i48T, i8T, i8T},
+        {i48T, i48T, i16T, i16T},
+        {i48T, i48T, i32T, i32T}}}}},
     {"tosa.const",
      {{{Extension::int4}, {{i4T}}},
       {{Extension::int16}, {{i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 097f78cd487ea..cd593a2816355 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2337,8 +2337,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     Tosa_Tensor:$input,
     Tosa_1DInt16Or32Tensor:$multiplier,
     Tosa_1DInt8Tensor:$shift,
-    I32Attr:$input_zp,
-    I32Attr:$output_zp,
+    Tosa_ScalarIntOrFloatTensor:$input_zp,
+    Tosa_ScalarIntOrFloatTensor:$output_zp,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2355,6 +2355,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     Extension<[Tosa_EXT_INT16]>,
   ];
 
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getInputZeroPoint();
+    FailureOr<int64_t> getOutputZeroPoint();
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
+
   let hasVerifier = 1;
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..89af54132f820 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
 
 template <typename T>
 static arith::ConstantOp
-createConstFromIntAttribute(Operation *op, const std::string &attrName,
-                            Type requiredAttrType, OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(
-      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
+createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
+                       OpBuilder &rewriter) {
+  auto castedN = static_cast<T>(zp);
   return rewriter.create<arith::ConstantOp>(
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           // later.
           int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
 
-          auto inputZp = createConstFromIntAttribute<int32_t>(
-              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+          if (failed(maybeIZp)) {
+            (void)rewriter.notifyMatchFailure(
+                op, "input zero point cannot be statically determined");
+            return;
+          }
+
+          auto inputZp = createConstOpFromZpVal<int32_t>(
+              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
               nestedBuilder);
-          auto outputZp = createConstFromIntAttribute<int32_t>(
-              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+          if (failed(maybeOZp)) {
+            (void)rewriter.notifyMatchFailure(
+                op, "output zero point cannot be statically determined");
+            return;
+          };
+
+          auto outputZp = createConstOpFromZpVal<int32_t>(
+              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4711122dc76e2..4d1f0be567d3c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -254,6 +254,27 @@ static Type getStorageElementTypeOrSelf(Type type) {
   return elementType;
 }
 
+static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
+                                                  Value valZp, StringRef name) {
+  Type eType = getStorageElementTypeOrSelf(val.getType());
+  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
+
+  bool bothInts =
+      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
+  bool bothFloats =
+      mlir::isa<FloatType>(eType) && mlir::isa<FloatType>(eZpType);
+  bool sameBitWidth =
+      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
+
+  if ((!bothInts && !bothFloats) || !sameBitWidth) {
+    return op->emitOpError()
+           << "expected " << name << " and " << name
+           << "_zp to both be integer or float of the same bitwidth, but got "
+           << eType << " vs. " << eZpType;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -1696,6 +1717,38 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
   return success();
 }
 
+static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
+                                     const int64_t &zp,
+                                     const std::string &operand) {
+  bool isInputZp = (zpVal == op.getInputZp());
+  bool isOutputZp = (zpVal == op.getOutputZp());
+  if (!isInputZp && !isOutputZp) {
+    return op.emitOpError("internal error: zero-point operand is neither "
+                          "inputZp nor outputZp");
+  }
+
+  bool tensorUnsigned =
+      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+  StringRef tensorName = isInputZp ? "input" : "output";
+
+  Type zpElemType = getElementTypeOrSelf(zpVal);
+
+  if (zp != 0) {
+    if (!zpElemType.isInteger(8) &&
+        !(zpElemType.isInteger(16) && tensorUnsigned)) {
+      return op.emitOpError()
+             << "expect " << tensorName << "_zp of 0, got " << zp;
+    }
+    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+      return op.emitOpError() << "expect " << tensorName
+                              << "_zp of 0 or 32768 for unsigned int16 "
+                              << tensorName << ", got " << zp;
+    }
+  }
+
+  return success();
+}
+
 #define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
   FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
     return getZeroPoint(*this, get##OPERAND_NAME##Zp());                       \
@@ -1714,6 +1767,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(RescaleOp, Input)
+ZERO_POINT_HELPER(RescaleOp, Output)
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2698,41 +2753,21 @@ LogicalResult RescaleOp::verify() {
     return failure();
   }
 
-  auto input_zp = getInputZpAttr().getInt();
-  if (input_zp != 0) {
-    // only int8/uint8 and uint16 input can have non-zero input_zp
-    if (!inputElementType.isInteger(8) &&
-        !(inputElementType.isInteger(16) && getInputUnsigned())) {
-      emitOpError("expect input_zp of 0, got ") << input_zp;
-      return failure();
-    }
-    // input_zp must be either 0 or 32768 for uint16 input
-    if (inputElementType.isInteger(16) && getInputUnsigned() &&
-        input_zp != 32768) {
-      emitOpError(
-          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
-          << input_zp;
-      return failure();
-    }
-  }
+  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
+          .failed())
+    return failure();
 
-  auto output_zp = getOutputZpAttr().getInt();
-  if (output_zp != 0) {
-    // only int8/uint8 and uint16 output can have non-zero output_zp
-    if (!outputElementType.isInteger(8) &&
-        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
-      emitOpError("expect output_zp of 0, got ") << output_zp;
-      return failure();
-    }
-    // output_zp must be either 0 or 32768 for uint16 output
-    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
-        output_zp != 32768) {
-      emitOpError(
-          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
-          << output_zp;
-      return failure();
-    }
-  }
+  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
+          .failed())
+    return failure();
+
+  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
+  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+    return failure();
 
   auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
   if (!multiplierType) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..4fba2e5dfde2b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
 template <>
 void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getInput());
+  addValue(op.getInputZp());
+  addValue(op.getOutputZp());
   addValue(op.getOutput());
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..842b50e804cbe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..49d3e86e8fcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,9 +1149,11 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1182,7 +1184,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1195,19 +1199,22 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
 // CHECK-LABEL: @rescale_i8_dyn_batch
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1219,15 +1226,19 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
 // CHECK-LABEL: @rescale_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32>} : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: %[[C1:.+]] = arith.constant 1
   // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
   // CHECK: %[[C2:.+]] = arith.constant 2
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
-  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+
   return
 }
 
@@ -1257,7 +1268,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1x...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Tai Ly (Tai78641)

Changes

*Update RescaleOp to use zero-point as operands instead of attributes.
*Check input_zp data type against the input and output_zp data type
against the output.


Patch is 61.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130340.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13-10)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+9-2)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-8)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+69-34)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+4-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+45-24)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+7-5)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+54-18)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+5-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-2)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+15)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+6-2)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+10-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..08f28a7538c3d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -216,15 +216,15 @@ profileComplianceMap = {
         {fp32T, fp16T}}}}},
     {"tosa.rescale",
      {{{Profile::pro_int},
-       {{i8T, i8T},
-        {i8T, i16T},
-        {i8T, i32T},
-        {i16T, i8T},
-        {i16T, i16T},
-        {i16T, i32T},
-        {i32T, i8T},
-        {i32T, i16T},
-        {i32T, i32T}}}}},
+       {{i8T, i8T, i8T, i8T},
+        {i8T, i8T, i16T, i16T},
+        {i8T, i8T, i32T, i32T},
+        {i16T, i16T, i8T, i8T},
+        {i16T, i16T, i16T, i16T},
+        {i16T, i16T, i32T, i32T},
+        {i32T, i32T, i8T, i8T},
+        {i32T, i32T, i16T, i16T},
+        {i32T, i32T, i32T, i32T}}}}},
     {"tosa.const",
      {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
       {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -384,7 +384,10 @@ extensionComplianceMap = {
         {fp16T, fp8e5m2T},
         {fp32T, fp8e5m2T}}}}},
     {"tosa.rescale",
-     {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+     {{{Extension::int16},
+       {{i48T, i48T, i8T, i8T},
+        {i48T, i48T, i16T, i16T},
+        {i48T, i48T, i32T, i32T}}}}},
     {"tosa.const",
      {{{Extension::int4}, {{i4T}}},
       {{Extension::int16}, {{i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 097f78cd487ea..cd593a2816355 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2337,8 +2337,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     Tosa_Tensor:$input,
     Tosa_1DInt16Or32Tensor:$multiplier,
     Tosa_1DInt8Tensor:$shift,
-    I32Attr:$input_zp,
-    I32Attr:$output_zp,
+    Tosa_ScalarIntOrFloatTensor:$input_zp,
+    Tosa_ScalarIntOrFloatTensor:$output_zp,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2355,6 +2355,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     Extension<[Tosa_EXT_INT16]>,
   ];
 
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getInputZeroPoint();
+    FailureOr<int64_t> getOutputZeroPoint();
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
+
   let hasVerifier = 1;
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..89af54132f820 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
 
 template <typename T>
 static arith::ConstantOp
-createConstFromIntAttribute(Operation *op, const std::string &attrName,
-                            Type requiredAttrType, OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(
-      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
+createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
+                       OpBuilder &rewriter) {
+  auto castedN = static_cast<T>(zp);
   return rewriter.create<arith::ConstantOp>(
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           // later.
           int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
 
-          auto inputZp = createConstFromIntAttribute<int32_t>(
-              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+          if (failed(maybeIZp)) {
+            (void)rewriter.notifyMatchFailure(
+                op, "input zero point cannot be statically determined");
+            return;
+          }
+
+          auto inputZp = createConstOpFromZpVal<int32_t>(
+              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
               nestedBuilder);
-          auto outputZp = createConstFromIntAttribute<int32_t>(
-              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+          if (failed(maybeOZp)) {
+            (void)rewriter.notifyMatchFailure(
+                op, "output zero point cannot be statically determined");
+            return;
+          };
+
+          auto outputZp = createConstOpFromZpVal<int32_t>(
+              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4711122dc76e2..4d1f0be567d3c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -254,6 +254,27 @@ static Type getStorageElementTypeOrSelf(Type type) {
   return elementType;
 }
 
+static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
+                                                  Value valZp, StringRef name) {
+  Type eType = getStorageElementTypeOrSelf(val.getType());
+  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
+
+  bool bothInts =
+      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
+  bool bothFloats =
+      mlir::isa<FloatType>(eType) && mlir::isa<FloatType>(eZpType);
+  bool sameBitWidth =
+      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
+
+  if ((!bothInts && !bothFloats) || !sameBitWidth) {
+    return op->emitOpError()
+           << "expected " << name << " and " << name
+           << "_zp to both be integer or float of the same bitwidth, but got "
+           << eType << " vs. " << eZpType;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -1696,6 +1717,38 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
   return success();
 }
 
+static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
+                                     const int64_t &zp,
+                                     const std::string &operand) {
+  bool isInputZp = (zpVal == op.getInputZp());
+  bool isOutputZp = (zpVal == op.getOutputZp());
+  if (!isInputZp && !isOutputZp) {
+    return op.emitOpError("internal error: zero-point operand is neither "
+                          "inputZp nor outputZp");
+  }
+
+  bool tensorUnsigned =
+      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+  StringRef tensorName = isInputZp ? "input" : "output";
+
+  Type zpElemType = getElementTypeOrSelf(zpVal);
+
+  if (zp != 0) {
+    if (!zpElemType.isInteger(8) &&
+        !(zpElemType.isInteger(16) && tensorUnsigned)) {
+      return op.emitOpError()
+             << "expect " << tensorName << "_zp of 0, got " << zp;
+    }
+    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+      return op.emitOpError() << "expect " << tensorName
+                              << "_zp of 0 or 32768 for unsigned int16 "
+                              << tensorName << ", got " << zp;
+    }
+  }
+
+  return success();
+}
+
 #define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
   FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
     return getZeroPoint(*this, get##OPERAND_NAME##Zp());                       \
@@ -1714,6 +1767,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(RescaleOp, Input)
+ZERO_POINT_HELPER(RescaleOp, Output)
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2698,41 +2753,21 @@ LogicalResult RescaleOp::verify() {
     return failure();
   }
 
-  auto input_zp = getInputZpAttr().getInt();
-  if (input_zp != 0) {
-    // only int8/uint8 and uint16 input can have non-zero input_zp
-    if (!inputElementType.isInteger(8) &&
-        !(inputElementType.isInteger(16) && getInputUnsigned())) {
-      emitOpError("expect input_zp of 0, got ") << input_zp;
-      return failure();
-    }
-    // input_zp must be either 0 or 32768 for uint16 input
-    if (inputElementType.isInteger(16) && getInputUnsigned() &&
-        input_zp != 32768) {
-      emitOpError(
-          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
-          << input_zp;
-      return failure();
-    }
-  }
+  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
+          .failed())
+    return failure();
 
-  auto output_zp = getOutputZpAttr().getInt();
-  if (output_zp != 0) {
-    // only int8/uint8 and uint16 output can have non-zero output_zp
-    if (!outputElementType.isInteger(8) &&
-        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
-      emitOpError("expect output_zp of 0, got ") << output_zp;
-      return failure();
-    }
-    // output_zp must be either 0 or 32768 for uint16 output
-    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
-        output_zp != 32768) {
-      emitOpError(
-          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
-          << output_zp;
-      return failure();
-    }
-  }
+  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
+          .failed())
+    return failure();
+
+  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
+  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+    return failure();
 
   auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
   if (!multiplierType) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..4fba2e5dfde2b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
 template <>
 void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getInput());
+  addValue(op.getInputZp());
+  addValue(op.getOutputZp());
   addValue(op.getOutput());
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..842b50e804cbe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..49d3e86e8fcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,9 +1149,11 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1182,7 +1184,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1195,19 +1199,22 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
 // CHECK-LABEL: @rescale_i8_dyn_batch
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1219,15 +1226,19 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
 // CHECK-LABEL: @rescale_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32>} : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: %[[C1:.+]] = arith.constant 1
   // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
   // CHECK: %[[C2:.+]] = arith.constant 2
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
-  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+
   return
 }
 
@@ -1257,7 +1268,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1x...
[truncated]

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

 *Update RescaleOp to use zero-point as operands instead of attributes.
 *Check input_zp data type against the input and output_zp data type
   against the output.

Change-Id: I2cf0106eb9f9ec88e16de5efc93b651053e5fc92
Signed-off-by: Peng Sun <[email protected]>
@Tai78641
Copy link
Contributor Author

rebased and resolved lit tests merge issues

@GeorgeARM GeorgeARM merged commit 913d077 into llvm:main Mar 11, 2025
11 checks passed
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 21, 2025
Changes needed for llvm/llvm-project#130337,
llvm/llvm-project#129720,
and llvm/llvm-project#130340

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643
PiperOrigin-RevId: 738783395
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 21, 2025
Changes needed for llvm/llvm-project#130337,
llvm/llvm-project#129720,
and llvm/llvm-project#130340

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643
PiperOrigin-RevId: 738783395
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to openxla/shardy that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants