-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Change Rescale zero points to be inputs #130340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) Changes*Update RescaleOp to use zero-point as operands instead of attributes. Patch is 61.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130340.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..08f28a7538c3d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -216,15 +216,15 @@ profileComplianceMap = {
{fp32T, fp16T}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, i8T},
- {i16T, i16T},
- {i16T, i32T},
- {i32T, i8T},
- {i32T, i16T},
- {i32T, i32T}}}}},
+ {{i8T, i8T, i8T, i8T},
+ {i8T, i8T, i16T, i16T},
+ {i8T, i8T, i32T, i32T},
+ {i16T, i16T, i8T, i8T},
+ {i16T, i16T, i16T, i16T},
+ {i16T, i16T, i32T, i32T},
+ {i32T, i32T, i8T, i8T},
+ {i32T, i32T, i16T, i16T},
+ {i32T, i32T, i32T, i32T}}}}},
{"tosa.const",
{{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
{{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -384,7 +384,10 @@ extensionComplianceMap = {
{fp16T, fp8e5m2T},
{fp32T, fp8e5m2T}}}}},
{"tosa.rescale",
- {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+ {{{Extension::int16},
+ {{i48T, i48T, i8T, i8T},
+ {i48T, i48T, i16T, i16T},
+ {i48T, i48T, i32T, i32T}}}}},
{"tosa.const",
{{{Extension::int4}, {{i4T}}},
{{Extension::int16}, {{i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 097f78cd487ea..cd593a2816355 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2337,8 +2337,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Tosa_Tensor:$input,
Tosa_1DInt16Or32Tensor:$multiplier,
Tosa_1DInt8Tensor:$shift,
- I32Attr:$input_zp,
- I32Attr:$output_zp,
+ Tosa_ScalarIntOrFloatTensor:$input_zp,
+ Tosa_ScalarIntOrFloatTensor:$output_zp,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel,
@@ -2355,6 +2355,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Extension<[Tosa_EXT_INT16]>,
];
+ let extraClassDeclaration = [{
+ FailureOr<int64_t> getInputZeroPoint();
+ FailureOr<int64_t> getOutputZeroPoint();
+ LogicalResult verifyInputZeroPoint(int64_t zp);
+ LogicalResult verifyOutputZeroPoint(int64_t zp);
+ }];
+
let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..89af54132f820 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
template <typename T>
static arith::ConstantOp
-createConstFromIntAttribute(Operation *op, const std::string &attrName,
- Type requiredAttrType, OpBuilder &rewriter) {
- auto castedN = static_cast<T>(
- cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
+createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
+ OpBuilder &rewriter) {
+ auto castedN = static_cast<T>(zp);
return rewriter.create<arith::ConstantOp>(
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// later.
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
- auto inputZp = createConstFromIntAttribute<int32_t>(
- op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ if (failed(maybeIZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "input zero point cannot be statically determined");
+ return;
+ }
+
+ auto inputZp = createConstOpFromZpVal<int32_t>(
+ op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
- auto outputZp = createConstFromIntAttribute<int32_t>(
- op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ if (failed(maybeOZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return;
+ };
+
+ auto outputZp = createConstOpFromZpVal<int32_t>(
+ op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4711122dc76e2..4d1f0be567d3c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -254,6 +254,27 @@ static Type getStorageElementTypeOrSelf(Type type) {
return elementType;
}
+static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
+ Value valZp, StringRef name) {
+ Type eType = getStorageElementTypeOrSelf(val.getType());
+ Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
+
+ bool bothInts =
+ mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
+ bool bothFloats =
+ mlir::isa<FloatType>(eType) && mlir::isa<FloatType>(eZpType);
+ bool sameBitWidth =
+ (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
+
+ if ((!bothInts && !bothFloats) || !sameBitWidth) {
+ return op->emitOpError()
+ << "expected " << name << " and " << name
+ << "_zp to both be integer or float of the same bitwidth, but got "
+ << eType << " vs. " << eZpType;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
@@ -1696,6 +1717,38 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
return success();
}
+static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
+ const int64_t &zp,
+ const std::string &operand) {
+ bool isInputZp = (zpVal == op.getInputZp());
+ bool isOutputZp = (zpVal == op.getOutputZp());
+ if (!isInputZp && !isOutputZp) {
+ return op.emitOpError("internal error: zero-point operand is neither "
+ "inputZp nor outputZp");
+ }
+
+ bool tensorUnsigned =
+ isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+ StringRef tensorName = isInputZp ? "input" : "output";
+
+ Type zpElemType = getElementTypeOrSelf(zpVal);
+
+ if (zp != 0) {
+ if (!zpElemType.isInteger(8) &&
+ !(zpElemType.isInteger(16) && tensorUnsigned)) {
+ return op.emitOpError()
+ << "expect " << tensorName << "_zp of 0, got " << zp;
+ }
+ if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+ return op.emitOpError() << "expect " << tensorName
+ << "_zp of 0 or 32768 for unsigned int16 "
+ << tensorName << ", got " << zp;
+ }
+ }
+
+ return success();
+}
+
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
@@ -1714,6 +1767,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(RescaleOp, Input)
+ZERO_POINT_HELPER(RescaleOp, Output)
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2698,41 +2753,21 @@ LogicalResult RescaleOp::verify() {
return failure();
}
- auto input_zp = getInputZpAttr().getInt();
- if (input_zp != 0) {
- // only int8/uint8 and uint16 input can have non-zero input_zp
- if (!inputElementType.isInteger(8) &&
- !(inputElementType.isInteger(16) && getInputUnsigned())) {
- emitOpError("expect input_zp of 0, got ") << input_zp;
- return failure();
- }
- // input_zp must be either 0 or 32768 for uint16 input
- if (inputElementType.isInteger(16) && getInputUnsigned() &&
- input_zp != 32768) {
- emitOpError(
- "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
- << input_zp;
- return failure();
- }
- }
+ if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
+ .failed())
+ return failure();
- auto output_zp = getOutputZpAttr().getInt();
- if (output_zp != 0) {
- // only int8/uint8 and uint16 output can have non-zero output_zp
- if (!outputElementType.isInteger(8) &&
- !(outputElementType.isInteger(16) && getOutputUnsigned())) {
- emitOpError("expect output_zp of 0, got ") << output_zp;
- return failure();
- }
- // output_zp must be either 0 or 32768 for uint16 output
- if (outputElementType.isInteger(16) && getOutputUnsigned() &&
- output_zp != 32768) {
- emitOpError(
- "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
- << output_zp;
- return failure();
- }
- }
+ if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
+ .failed())
+ return failure();
+
+ FailureOr<int64_t> maybeIZp = getInputZeroPoint();
+ if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+ return failure();
+
+ FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+ return failure();
auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
if (!multiplierType) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..4fba2e5dfde2b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
template <>
void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
addValue(op.getInput());
+ addValue(op.getInputZp());
+ addValue(op.getOutputZp());
addValue(op.getOutput());
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..842b50e804cbe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..49d3e86e8fcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,9 +1149,11 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
- %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
- %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1182,7 +1184,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1195,19 +1199,22 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-LABEL: @rescale_i8_dyn_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
- %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
- %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
return
}
@@ -1219,15 +1226,19 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK-LABEL: @rescale_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+ %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32>} : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<38> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+
// CHECK: %[[C1:.+]] = arith.constant 1
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.+]] = arith.constant 2
// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
- %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
- %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+
return
}
@@ -1257,7 +1268,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1x...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Tai Ly (Tai78641) Changes*Update RescaleOp to use zero-point as operands instead of attributes. Patch is 61.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130340.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..08f28a7538c3d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -216,15 +216,15 @@ profileComplianceMap = {
{fp32T, fp16T}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, i8T},
- {i16T, i16T},
- {i16T, i32T},
- {i32T, i8T},
- {i32T, i16T},
- {i32T, i32T}}}}},
+ {{i8T, i8T, i8T, i8T},
+ {i8T, i8T, i16T, i16T},
+ {i8T, i8T, i32T, i32T},
+ {i16T, i16T, i8T, i8T},
+ {i16T, i16T, i16T, i16T},
+ {i16T, i16T, i32T, i32T},
+ {i32T, i32T, i8T, i8T},
+ {i32T, i32T, i16T, i16T},
+ {i32T, i32T, i32T, i32T}}}}},
{"tosa.const",
{{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
{{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -384,7 +384,10 @@ extensionComplianceMap = {
{fp16T, fp8e5m2T},
{fp32T, fp8e5m2T}}}}},
{"tosa.rescale",
- {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+ {{{Extension::int16},
+ {{i48T, i48T, i8T, i8T},
+ {i48T, i48T, i16T, i16T},
+ {i48T, i48T, i32T, i32T}}}}},
{"tosa.const",
{{{Extension::int4}, {{i4T}}},
{{Extension::int16}, {{i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 097f78cd487ea..cd593a2816355 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2337,8 +2337,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Tosa_Tensor:$input,
Tosa_1DInt16Or32Tensor:$multiplier,
Tosa_1DInt8Tensor:$shift,
- I32Attr:$input_zp,
- I32Attr:$output_zp,
+ Tosa_ScalarIntOrFloatTensor:$input_zp,
+ Tosa_ScalarIntOrFloatTensor:$output_zp,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel,
@@ -2355,6 +2355,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Extension<[Tosa_EXT_INT16]>,
];
+ let extraClassDeclaration = [{
+ FailureOr<int64_t> getInputZeroPoint();
+ FailureOr<int64_t> getOutputZeroPoint();
+ LogicalResult verifyInputZeroPoint(int64_t zp);
+ LogicalResult verifyOutputZeroPoint(int64_t zp);
+ }];
+
let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..89af54132f820 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
template <typename T>
static arith::ConstantOp
-createConstFromIntAttribute(Operation *op, const std::string &attrName,
- Type requiredAttrType, OpBuilder &rewriter) {
- auto castedN = static_cast<T>(
- cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
+createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
+ OpBuilder &rewriter) {
+ auto castedN = static_cast<T>(zp);
return rewriter.create<arith::ConstantOp>(
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// later.
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
- auto inputZp = createConstFromIntAttribute<int32_t>(
- op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ if (failed(maybeIZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "input zero point cannot be statically determined");
+ return;
+ }
+
+ auto inputZp = createConstOpFromZpVal<int32_t>(
+ op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
- auto outputZp = createConstFromIntAttribute<int32_t>(
- op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ if (failed(maybeOZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return;
+ };
+
+ auto outputZp = createConstOpFromZpVal<int32_t>(
+ op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4711122dc76e2..4d1f0be567d3c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -254,6 +254,27 @@ static Type getStorageElementTypeOrSelf(Type type) {
return elementType;
}
+static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
+ Value valZp, StringRef name) {
+ Type eType = getStorageElementTypeOrSelf(val.getType());
+ Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
+
+ bool bothInts =
+ mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
+ bool bothFloats =
+ mlir::isa<FloatType>(eType) && mlir::isa<FloatType>(eZpType);
+ bool sameBitWidth =
+ (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
+
+ if ((!bothInts && !bothFloats) || !sameBitWidth) {
+ return op->emitOpError()
+ << "expected " << name << " and " << name
+ << "_zp to both be integer or float of the same bitwidth, but got "
+ << eType << " vs. " << eZpType;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
@@ -1696,6 +1717,38 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
return success();
}
+static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
+ const int64_t &zp,
+ const std::string &operand) {
+ bool isInputZp = (zpVal == op.getInputZp());
+ bool isOutputZp = (zpVal == op.getOutputZp());
+ if (!isInputZp && !isOutputZp) {
+ return op.emitOpError("internal error: zero-point operand is neither "
+ "inputZp nor outputZp");
+ }
+
+ bool tensorUnsigned =
+ isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+ StringRef tensorName = isInputZp ? "input" : "output";
+
+ Type zpElemType = getElementTypeOrSelf(zpVal);
+
+ if (zp != 0) {
+ if (!zpElemType.isInteger(8) &&
+ !(zpElemType.isInteger(16) && tensorUnsigned)) {
+ return op.emitOpError()
+ << "expect " << tensorName << "_zp of 0, got " << zp;
+ }
+ if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+ return op.emitOpError() << "expect " << tensorName
+ << "_zp of 0 or 32768 for unsigned int16 "
+ << tensorName << ", got " << zp;
+ }
+ }
+
+ return success();
+}
+
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
@@ -1714,6 +1767,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(RescaleOp, Input)
+ZERO_POINT_HELPER(RescaleOp, Output)
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2698,41 +2753,21 @@ LogicalResult RescaleOp::verify() {
return failure();
}
- auto input_zp = getInputZpAttr().getInt();
- if (input_zp != 0) {
- // only int8/uint8 and uint16 input can have non-zero input_zp
- if (!inputElementType.isInteger(8) &&
- !(inputElementType.isInteger(16) && getInputUnsigned())) {
- emitOpError("expect input_zp of 0, got ") << input_zp;
- return failure();
- }
- // input_zp must be either 0 or 32768 for uint16 input
- if (inputElementType.isInteger(16) && getInputUnsigned() &&
- input_zp != 32768) {
- emitOpError(
- "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
- << input_zp;
- return failure();
- }
- }
+ if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
+ .failed())
+ return failure();
- auto output_zp = getOutputZpAttr().getInt();
- if (output_zp != 0) {
- // only int8/uint8 and uint16 output can have non-zero output_zp
- if (!outputElementType.isInteger(8) &&
- !(outputElementType.isInteger(16) && getOutputUnsigned())) {
- emitOpError("expect output_zp of 0, got ") << output_zp;
- return failure();
- }
- // output_zp must be either 0 or 32768 for uint16 output
- if (outputElementType.isInteger(16) && getOutputUnsigned() &&
- output_zp != 32768) {
- emitOpError(
- "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
- << output_zp;
- return failure();
- }
- }
+ if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
+ .failed())
+ return failure();
+
+ FailureOr<int64_t> maybeIZp = getInputZeroPoint();
+ if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+ return failure();
+
+ FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+ return failure();
auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
if (!multiplierType) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..4fba2e5dfde2b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
template <>
void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
addValue(op.getInput());
+ addValue(op.getInputZp());
+ addValue(op.getOutputZp());
addValue(op.getOutput());
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..842b50e804cbe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..49d3e86e8fcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,9 +1149,11 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
- %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
- %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1182,7 +1184,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1195,19 +1199,22 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-LABEL: @rescale_i8_dyn_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
- %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
- %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
return
}
@@ -1219,15 +1226,19 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK-LABEL: @rescale_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+ %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32>} : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<38> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+
// CHECK: %[[C1:.+]] = arith.constant 1
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.+]] = arith.constant 2
// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
- %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
- %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+
return
}
@@ -1257,7 +1268,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1x...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
*Update RescaleOp to use zero-point as operands instead of attributes. *Check input_zp data type against the input and output_zp data type against the output. Change-Id: I2cf0106eb9f9ec88e16de5efc93b651053e5fc92 Signed-off-by: Peng Sun <[email protected]>
rebased and resolved lit tests merge issues |
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
*Update RescaleOp to use zero-point as operands instead of attributes.
*Check input_zp data type against the input and output_zp data type
against the output.