
[mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs #129720


Merged · 1 commit into llvm:main on Mar 7, 2025

Conversation

psunn
Contributor

@psunn psunn commented Mar 4, 2025

This patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification.

Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions.
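For illustration, the change to the op's textual form looks like this (operand values adapted from the updated `tosa-to-linalg.mlir` test in this patch):

```mlir
// Before: multiplier and shift carried as attributes.
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>

// After: multiplier and shift passed as explicit 1-D tensor inputs.
%multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
%shift = "tosa.const"() {value = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
```

Note that with `scale32 = false` the multiplier tensor has `i16` elements; with `scale32 = true` it would be `tensor<1xi32>`.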

@llvmbot
Member

llvmbot commented Mar 4, 2025

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Peng Sun (psunn)

Changes


Patch is 50.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129720.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+5-5)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+26-6)
  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h (+12)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+18-2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+133-1)
  • (modified) mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (+39)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+3-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+25-9)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+181)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+7-3)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+11-9)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 85bd3fb1bb1cc..43ef79e6f63f0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2262,9 +2262,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 //===----------------------------------------------------------------------===//
 // Operator: rescale
 //===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>]> {
+def Tosa_RescaleOp: Tosa_InferShapedTypeOp<"rescale"> {
   let summary = "Tosa rescale operator";
 
   let description = [{
@@ -2290,10 +2288,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
 
   let arguments = (ins
     Tosa_Tensor:$input,
+    Tosa_1DInt16Or32Tensor:$multiplier,
+    Tosa_1DInt8Tensor:$shift,
     I32Attr:$input_zp,
     I32Attr:$output_zp,
-    DenseI32ArrayAttr:$multiplier,
-    DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2310,6 +2308,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     Extension<[Tosa_EXT_INT16]>,
   ];
 
+  let hasVerifier = 1;
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..881e423d0950e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 
 def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
+def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
 def Tosa_Int64 : I<64>;
 
@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
                           AnySignlessInteger]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
-                   	        Tosa_Int64]>;
+                          Tosa_Int64]>;
+
+def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
+                          Tosa_Int32]>;
 
 //===----------------------------------------------------------------------===//
 // Quantized Integer Types.
@@ -68,11 +72,23 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
 // int8  : symmetric  per tensor/per channel, signed
 // int16 : symmetric  per tensor,             signed
 //===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
-                                   Tosa_QuantizedType<"int4", [4, 0], 1>,
-                                   Tosa_QuantizedType<"int8", [8, 0], 1>,
-                                   Tosa_QuantizedType<"int16", [16, 0], 1>,
-                                   Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt	: AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
+                                     Tosa_QuantizedType<"int4", [4, 0], 1>,
+                                     Tosa_QuantizedType<"int8", [8, 0], 1>,
+                                     Tosa_QuantizedType<"int16", [16, 0], 1>,
+                                     Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def Tosa_Float : AnyTypeOf<[
+                            F32,
+                            F16,
+                            BF16]>;
+
+def Tosa_F8 : AnyTypeOf<[
+                        F8E4M3FN,
+                        F8E5M2]>;
 
 //===----------------------------------------------------------------------===//
 // Multi-category types.
@@ -162,6 +178,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
+// 1D tensor of specific types
+def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
+def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 10dc5dd36cfa9..0124b37b018a7 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -31,6 +31,18 @@ namespace tosa {
 void computeMultiplierAndShift(double scale, int32_t &multiplier,
                                int32_t &shift, int32_t scaleWidth);
 
+// Return a const value for array of int8 vec
+Value getConstTensorInt8(OpBuilder &builder, Location loc,
+                         ArrayRef<int8_t> vec);
+
+// Return a const value for array of int16 vec
+Value getConstTensorInt16(OpBuilder &builder, Location loc,
+                          ArrayRef<int16_t> vec);
+
+// Return a const value for array of int32 vec
+Value getConstTensorInt32(OpBuilder &builder, Location loc,
+                          ArrayRef<int32_t> vec);
+
 //// Builds ConvOpQuantizationAttr from input and weight.
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..5d1344d17ead0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     }
 
     // The shift and multiplier values.
-    SmallVector<int32_t> multiplierValues(op.getMultiplier());
-    SmallVector<int8_t> shiftValues(op.getShift());
+    ElementsAttr shiftElems;
+    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant shift input values");
+
+    ElementsAttr multiplierElems;
+    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant multiplier input values");
+
+    SmallVector<int8_t> shiftValues;
+    for (auto idx : shiftElems.getValues<IntegerAttr>()) {
+      shiftValues.push_back(static_cast<int8_t>(idx.getInt()));
+    }
+    SmallVector<int32_t> multiplierValues;
+    for (auto idx : multiplierElems.getValues<IntegerAttr>()) {
+      multiplierValues.push_back(static_cast<int32_t>(idx.getInt()));
+    }
 
     // If we shift by more than the bitwidth, this just sets to 0.
     for (int i = 0, s = multiplierValues.size(); i < s; i++) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1050f3f30fe98..12efce82942b4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2024,7 +2024,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::RescaleOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
 NARY_SHAPE_INFER(tosa::RsqrtOp)
 NARY_SHAPE_INFER(tosa::SinOp)
@@ -2469,6 +2468,139 @@ LogicalResult TransposeConv2DOp::verify() {
   return success();
 }
 
+LogicalResult RescaleOp::verify() {
+  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+  if (!inputType) {
+    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
+    return failure();
+  }
+  auto inputElementType = inputType.getElementType();
+  if (auto inputQType =
+          llvm::dyn_cast<quant::QuantizedType>(inputElementType)) {
+    inputElementType = inputQType.getStorageType();
+  }
+  if (!mlir::isa<IntegerType>(inputElementType)) {
+    emitOpError("expect input to have integer element type, got ")
+        << inputElementType;
+    return failure();
+  }
+  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+  if (!outputType) {
+    emitOpError("expect shaped tensor for output, got ")
+        << getOutput().getType();
+    return failure();
+  }
+  auto outputElementType = outputType.getElementType();
+  if (auto outputQType =
+          llvm::dyn_cast<quant::QuantizedType>(outputElementType)) {
+    outputElementType = outputQType.getStorageType();
+  }
+  if (!mlir::isa<IntegerType>(outputElementType)) {
+    emitOpError("expect output to have integer element type, got ")
+        << outputElementType;
+    return failure();
+  }
+  auto input_zp = getInputZpAttr().getInt();
+  if (input_zp != 0) {
+    // only int8/uint8 and uint16 input can have non-zero input_zp
+    if (!inputElementType.isInteger(8) &&
+        !(inputElementType.isInteger(16) && getInputUnsigned())) {
+      emitOpError("expect input_zp of 0, got ") << input_zp;
+      return failure();
+    }
+    // input_zp must be either 0 or 32768 for uint16 input
+    if (inputElementType.isInteger(16) && getInputUnsigned() &&
+        input_zp != 32768) {
+      emitOpError(
+          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
+          << input_zp;
+      return failure();
+    }
+  }
+
+  auto output_zp = getOutputZpAttr().getInt();
+  if (output_zp != 0) {
+    // only int8/uint8 and uint16 output can have non-zero output_zp
+    if (!outputElementType.isInteger(8) &&
+        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
+      emitOpError("expect output_zp of 0, got ") << output_zp;
+      return failure();
+    }
+    // output_zp must be either 0 or 32768 for uint16 output
+    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
+        output_zp != 32768) {
+      emitOpError(
+          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
+          << output_zp;
+      return failure();
+    }
+  }
+
+  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+  if (!multiplierType) {
+    emitOpError("expect shaped tensor for multiplier, got ")
+        << getMultiplier().getType();
+    return failure();
+  }
+  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+  if (!shiftType) {
+    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
+    return failure();
+  }
+
+  // multiplier element type must be i32 for scale32 = true
+  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
+    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
+        << multiplierType.getElementType();
+    return failure();
+  }
+
+  // multiplier element type must be i16 for scale32 = false
+  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
+    emitOpError(
+        "expect i16 element type for multiplier for scale32=false, got ")
+        << multiplierType.getElementType();
+    return failure();
+  }
+
+  // multiplier/shift must have shape = {numChannels},
+  // where numChannel is 1 if per_channel = false
+  // otherwise numChannel is dimension in input shape's last axis
+  int64_t numChannels = 1;
+  if (getPerChannel()) {
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    numChannels = inputShape[inputShape.size() - 1];
+  }
+
+  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+  // multiplier input has rank 1 by dialect definition
+  if (multiplierShape[0] != numChannels) {
+    emitOpError("expect shape of { ")
+        << numChannels << " } for multiplier input, got { "
+        << multiplierShape[0] << " }";
+    return failure();
+  }
+
+  ArrayRef<int64_t> shiftShape = shiftType.getShape();
+  // shift input has rank 1 by dialect definition
+  if (shiftShape[0] != numChannels) {
+    emitOpError("expect shape of { ")
+        << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
+    return failure();
+  }
+
+  return success();
+}
+
+LogicalResult RescaleOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RescaleOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 0f7562767001c..5c6bb8921491e 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -107,6 +107,45 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
   }
 }
 
+// Return a const value for array of int8 vec
+Value mlir::tosa::getConstTensorInt8(OpBuilder &builder, Location loc,
+                                     ArrayRef<int8_t> vec) {
+  int64_t count = vec.size();
+  assert(count > 0);
+  auto element_type = builder.getI8Type();
+  mlir::RankedTensorType const_type =
+      RankedTensorType::get({count}, element_type);
+  mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Return a const value for array of int16 vec
+Value mlir::tosa::getConstTensorInt16(OpBuilder &builder, Location loc,
+                                      ArrayRef<int16_t> vec) {
+  int64_t count = vec.size();
+  assert(count > 0);
+  auto element_type = builder.getI16Type();
+  mlir::RankedTensorType const_type =
+      RankedTensorType::get({count}, element_type);
+  mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Return a const value for array of int32 vec
+Value mlir::tosa::getConstTensorInt32(OpBuilder &builder, Location loc,
+                                      ArrayRef<int32_t> vec) {
+  int64_t count = vec.size();
+  assert(count > 0);
+  auto element_type = builder.getI32Type();
+  mlir::RankedTensorType const_type =
+      RankedTensorType::get({count}, element_type);
+  mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
 #define GET_UQTYPE(inputType)                                                  \
   (llvm::dyn_cast<quant::UniformQuantizedType>((inputType).getElementType()))
 #define GET_QTYPE(inputType)                                                   \
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index afc1d5c609181..ddaaf5fcf7120 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..16a49376e58f5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,7 +1149,9 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1178,7 +1180,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1191,17 +1195,19 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
 // CHECK-LABEL: @rescale_i8_dyn_batch
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> te...
[truncated]
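As a sketch of what the new verifier shown above enforces: with `per_channel = false`, the multiplier and shift inputs must be rank-1 tensors of shape { 1 } (with `per_channel = true`, the shape must match the input's last dimension). A hypothetical invalid case, constructed here for illustration:

```mlir
// Hypothetical example (not from the patch): per_channel = false, but the
// multiplier has shape {2}, so the new RescaleOp::verify() rejects it with
// an error of the form: expect shape of { 1 } for multiplier input
%multiplier = "tosa.const"() {value = dense<[19689, 19689]> : tensor<2xi32>} : () -> tensor<2xi32>
%shift = "tosa.const"() {value = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 0 : i32, output_zp = 0 : i32, scale32 = true, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<1xi8>) -> tensor<2xi8>
```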

@Jerry-Ge Jerry-Ge requested review from GeorgeARM and lhutton1 March 4, 2025 16:02
@psunn psunn force-pushed the psunn/RESCALE_multiplier_shift branch from 64f54d3 to 80d7b4e on March 5, 2025 19:13
@psunn psunn requested a review from GeorgeARM March 5, 2025 19:20
@psunn psunn force-pushed the psunn/RESCALE_multiplier_shift branch 3 times, most recently from e329cbb to 1081656 on March 5, 2025 19:46
Contributor

@GeorgeARM GeorgeARM left a comment


LGTM. Leaving to @lhutton1 for approval

@psunn psunn force-pushed the psunn/RESCALE_multiplier_shift branch from 1081656 to 80119ff on March 6, 2025 16:26
@psunn psunn requested a review from lhutton1 March 6, 2025 16:30
@psunn
Contributor Author

psunn commented Mar 6, 2025

Thank you, @GeorgeARM and @lhutton1, for reviewing!

@psunn psunn force-pushed the psunn/RESCALE_multiplier_shift branch from 80119ff to c836926 on March 6, 2025 22:24
@psunn psunn requested review from lhutton1 and Jerry-Ge March 6, 2025 22:26
Contributor

@lhutton1 lhutton1 left a comment


Thanks for the updates, LGTM! Since #129943 was merged, we need to update the const `value` attribute to `values`.

This patch updates the TOSA RescaleOp by converting its multiplier and
shift parameters from attributes to explicit inputs, aligning the op
with the TOSA V1.0 specification.
Additionally, this commit adds RescaleOp-specific implementations of
inferReturnTypeComponents and verify functions.

Co-authored-by: Tai Ly <[email protected]>
Signed-off-by: Peng Sun <[email protected]>
Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
@psunn psunn force-pushed the psunn/RESCALE_multiplier_shift branch from c836926 to c30014d on March 7, 2025 10:48
@psunn psunn requested a review from GeorgeARM March 7, 2025 10:49
@GeorgeARM GeorgeARM merged commit 5685def into llvm:main Mar 7, 2025
11 checks passed
@psunn psunn deleted the psunn/RESCALE_multiplier_shift branch March 7, 2025 18:07
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 20, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 20, 2025
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 21, 2025
Changes needed for llvm/llvm-project#130337,
llvm/llvm-project#129720,
and llvm/llvm-project#130340

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643
PiperOrigin-RevId: 738783395
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 21, 2025
Changes needed for llvm/llvm-project#130337,
llvm/llvm-project#129720,
and llvm/llvm-project#130340

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643
PiperOrigin-RevId: 738783395
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to openxla/shardy that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 21, 2025
alekstheod pushed a commit to ROCm/xla that referenced this pull request Apr 11, 2025
5 participants