-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs #129720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Peng Sun (psunn) ChangesThis patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification. Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions. Patch is 50.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129720.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 85bd3fb1bb1cc..43ef79e6f63f0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2262,9 +2262,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>]> {
+def Tosa_RescaleOp: Tosa_InferShapedTypeOp<"rescale"> {
let summary = "Tosa rescale operator";
let description = [{
@@ -2290,10 +2288,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
let arguments = (ins
Tosa_Tensor:$input,
+ Tosa_1DInt16Or32Tensor:$multiplier,
+ Tosa_1DInt8Tensor:$shift,
I32Attr:$input_zp,
I32Attr:$output_zp,
- DenseI32ArrayAttr:$multiplier,
- DenseI8ArrayAttr:$shift,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel,
@@ -2310,6 +2308,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
Extension<[Tosa_EXT_INT16]>,
];
+ let hasVerifier = 1;
+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..881e423d0950e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
+def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
def Tosa_Int64 : I<64>;
@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
AnySignlessInteger]>;
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
- Tosa_Int64]>;
+ Tosa_Int64]>;
+
+def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
+ Tosa_Int32]>;
//===----------------------------------------------------------------------===//
// Quantized Integer Types.
@@ -68,11 +72,23 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
// int8 : symmetric per tensor/per channel, signed
// int16 : symmetric per tensor, signed
//===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
- Tosa_QuantizedType<"int4", [4, 0], 1>,
- Tosa_QuantizedType<"int8", [8, 0], 1>,
- Tosa_QuantizedType<"int16", [16, 0], 1>,
- Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
+ Tosa_QuantizedType<"int4", [4, 0], 1>,
+ Tosa_QuantizedType<"int8", [8, 0], 1>,
+ Tosa_QuantizedType<"int16", [16, 0], 1>,
+ Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def Tosa_Float : AnyTypeOf<[
+ F32,
+ F16,
+ BF16]>;
+
+def Tosa_F8 : AnyTypeOf<[
+ F8E4M3FN,
+ F8E5M2]>;
//===----------------------------------------------------------------------===//
// Multi-category types.
@@ -162,6 +178,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
+// 1D tensor of specific types
+def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
+def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
+
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 10dc5dd36cfa9..0124b37b018a7 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -31,6 +31,18 @@ namespace tosa {
void computeMultiplierAndShift(double scale, int32_t &multiplier,
int32_t &shift, int32_t scaleWidth);
+// Return a const value for array of int8 vec
+Value getConstTensorInt8(OpBuilder &builder, Location loc,
+ ArrayRef<int8_t> vec);
+
+// Return a const value for array of int16 vec
+Value getConstTensorInt16(OpBuilder &builder, Location loc,
+ ArrayRef<int16_t> vec);
+
+// Return a const value for array of int32 vec
+Value getConstTensorInt32(OpBuilder &builder, Location loc,
+ ArrayRef<int32_t> vec);
+
//// Builds ConvOpQuantizationAttr from input and weight.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..5d1344d17ead0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
// The shift and multiplier values.
- SmallVector<int32_t> multiplierValues(op.getMultiplier());
- SmallVector<int8_t> shiftValues(op.getShift());
+ ElementsAttr shiftElems;
+ if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ return rewriter.notifyMatchFailure(
+ op, "tosa.rescale requires constant shift input values");
+
+ ElementsAttr multiplierElems;
+ if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ return rewriter.notifyMatchFailure(
+ op, "tosa.rescale requires constant multiplier input values");
+
+ SmallVector<int8_t> shiftValues;
+ for (auto idx : shiftElems.getValues<IntegerAttr>()) {
+ shiftValues.push_back(static_cast<int8_t>(idx.getInt()));
+ }
+ SmallVector<int32_t> multiplierValues;
+ for (auto idx : multiplierElems.getValues<IntegerAttr>()) {
+ multiplierValues.push_back(static_cast<int32_t>(idx.getInt()));
+ }
// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1050f3f30fe98..12efce82942b4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2024,7 +2024,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
@@ -2469,6 +2468,139 @@ LogicalResult TransposeConv2DOp::verify() {
return success();
}
+LogicalResult RescaleOp::verify() {
+ auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+ if (!inputType) {
+ emitOpError("expect shaped tensor for input, got ") << getInput().getType();
+ return failure();
+ }
+ auto inputElementType = inputType.getElementType();
+ if (auto inputQType =
+ llvm::dyn_cast<quant::QuantizedType>(inputElementType)) {
+ inputElementType = inputQType.getStorageType();
+ }
+ if (!mlir::isa<IntegerType>(inputElementType)) {
+ emitOpError("expect input to have integer element type, got ")
+ << inputElementType;
+ return failure();
+ }
+ auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+ if (!outputType) {
+ emitOpError("expect shaped tensor for output, got ")
+ << getOutput().getType();
+ return failure();
+ }
+ auto outputElementType = outputType.getElementType();
+ if (auto outputQType =
+ llvm::dyn_cast<quant::QuantizedType>(outputElementType)) {
+ outputElementType = outputQType.getStorageType();
+ }
+ if (!mlir::isa<IntegerType>(outputElementType)) {
+ emitOpError("expect output to have integer element type, got ")
+ << outputElementType;
+ return failure();
+ }
+ auto input_zp = getInputZpAttr().getInt();
+ if (input_zp != 0) {
+ // only int8/uint8 and uint16 input can have non-zero input_zp
+ if (!inputElementType.isInteger(8) &&
+ !(inputElementType.isInteger(16) && getInputUnsigned())) {
+ emitOpError("expect input_zp of 0, got ") << input_zp;
+ return failure();
+ }
+ // input_zp must be either 0 or 32768 for uint16 input
+ if (inputElementType.isInteger(16) && getInputUnsigned() &&
+ input_zp != 32768) {
+ emitOpError(
+ "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
+ << input_zp;
+ return failure();
+ }
+ }
+
+ auto output_zp = getOutputZpAttr().getInt();
+ if (output_zp != 0) {
+ // only int8/uint8 and uint16 output can have non-zero output_zp
+ if (!outputElementType.isInteger(8) &&
+ !(outputElementType.isInteger(16) && getOutputUnsigned())) {
+ emitOpError("expect output_zp of 0, got ") << output_zp;
+ return failure();
+ }
+ // output_zp must be either 0 or 32768 for uint16 output
+ if (outputElementType.isInteger(16) && getOutputUnsigned() &&
+ output_zp != 32768) {
+ emitOpError(
+ "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
+ << output_zp;
+ return failure();
+ }
+ }
+
+ auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+ if (!multiplierType) {
+ emitOpError("expect shaped tensor for multiplier, got ")
+ << getMultiplier().getType();
+ return failure();
+ }
+ auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+ if (!shiftType) {
+ emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
+ return failure();
+ }
+
+ // multiplier element type must be i32 for scale32 = true
+ if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
+ emitOpError("expect i32 element type for multiplier for scale32=true, got ")
+ << multiplierType.getElementType();
+ return failure();
+ }
+
+ // multiplier element type must be i16 for scale32 = false
+ if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
+ emitOpError(
+ "expect i16 element type for multiplier for scale32=false, got ")
+ << multiplierType.getElementType();
+ return failure();
+ }
+
+ // multiplier/shift must have shape = {numChannels},
+ // where numChannel is 1 if per_channel = false
+ // otherwise numChannel is dimension in input shape's last axis
+ int64_t numChannels = 1;
+ if (getPerChannel()) {
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ numChannels = inputShape[inputShape.size() - 1];
+ }
+
+ ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+ // multiplier input has rank 1 by dialect definition
+ if (multiplierShape[0] != numChannels) {
+ emitOpError("expect shape of { ")
+ << numChannels << " } for multiplier input, got { "
+ << multiplierShape[0] << " }";
+ return failure();
+ }
+
+ ArrayRef<int64_t> shiftShape = shiftType.getShape();
+ // shift input has rank 1 by dialect definition
+ if (shiftShape[0] != numChannels) {
+ emitOpError("expect shape of { ")
+ << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
+ return failure();
+ }
+
+ return success();
+}
+
+LogicalResult RescaleOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ RescaleOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 0f7562767001c..5c6bb8921491e 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -107,6 +107,45 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
}
}
+// Return a const value for array of int8 vec
+Value mlir::tosa::getConstTensorInt8(OpBuilder &builder, Location loc,
+ ArrayRef<int8_t> vec) {
+ int64_t count = vec.size();
+ assert(count > 0);
+ auto element_type = builder.getI8Type();
+ mlir::RankedTensorType const_type =
+ RankedTensorType::get({count}, element_type);
+ mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+ auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Return a const value for array of int16 vec
+Value mlir::tosa::getConstTensorInt16(OpBuilder &builder, Location loc,
+ ArrayRef<int16_t> vec) {
+ int64_t count = vec.size();
+ assert(count > 0);
+ auto element_type = builder.getI16Type();
+ mlir::RankedTensorType const_type =
+ RankedTensorType::get({count}, element_type);
+ mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+ auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Return a const value for array of int32 vec
+Value mlir::tosa::getConstTensorInt32(OpBuilder &builder, Location loc,
+ ArrayRef<int32_t> vec) {
+ int64_t count = vec.size();
+ assert(count > 0);
+ auto element_type = builder.getI32Type();
+ mlir::RankedTensorType const_type =
+ RankedTensorType::get({count}, element_type);
+ mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+ auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+ return const_op.getResult();
+}
+
#define GET_UQTYPE(inputType) \
(llvm::dyn_cast<quant::UniformQuantizedType>((inputType).getElementType()))
#define GET_QTYPE(inputType) \
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index afc1d5c609181..ddaaf5fcf7120 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..16a49376e58f5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,7 +1149,9 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
- %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1178,7 +1180,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK: linalg.yield [[TRUNC]]
- %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1191,17 +1195,19 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-LABEL: @rescale_i8_dyn_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> te...
[truncated]
|
64f54d3
to
80d7b4e
Compare
e329cbb
to
1081656
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Leaving to @lhutton1 for approval
1081656
to
80119ff
Compare
Thank you, @GeorgeARM and @lhutton1, for reviewing! |
80119ff
to
c836926
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates, LGTM! Since #129943 was merged we need to update const value
to values
This patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification. Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions. Co-authored-by: Tai Ly <[email protected]> Signed-off-by: Peng Sun <[email protected]> Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
c836926
to
c30014d
Compare
Changes needed for llvm/llvm-project#130337 and llvm/llvm-project#129720 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337 and llvm/llvm-project#129720 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
This patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification.
Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions.