Skip to content

Commit 8158d43

Browse files
authored
[TOSA] Rescale output_zp fix (#136116)
Patch corrects output_zp in case of usigned output
1 parent 46f18b7 commit 8158d43

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,14 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14901490
return;
14911491
};
14921492

1493+
// pre-process OutputZP as it can be unsigned
1494+
auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
1495+
APInt OZp(outBitwidth, !op.getOutputUnsigned());
1496+
OZp = static_cast<int64_t>(*maybeOZp);
1497+
*maybeOZp = op.getOutputUnsigned()
1498+
? static_cast<int64_t>(OZp.getZExtValue())
1499+
: OZp.getSExtValue();
1500+
14931501
auto outputZp = createConstOpFromZpVal<int32_t>(
14941502
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
14951503

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2028,7 +2028,8 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
20282028
return op.emitOpError()
20292029
<< "expect " << tensorName << "_zp of 0, got " << zp;
20302030
}
2031-
if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2031+
if (zpElemType.isInteger(16) && tensorUnsigned &&
2032+
zp != static_cast<int16_t>(32768)) {
20322033
return op.emitOpError() << "expect " << tensorName
20332034
<< "_zp of 0 or 32768 for unsigned int16 "
20342035
<< tensorName << ", got " << zp;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,11 +1161,11 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
11611161
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
11621162
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
11631163
// CHECK: [[C17:%.+]] = arith.constant 17
1164-
// CHECK: [[C22:%.+]] = arith.constant 22
1164+
// CHECK: [[C234:%.+]] = arith.constant 234
11651165
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
11661166
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
11671167
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1168-
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
1168+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
11691169
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
11701170
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
11711171
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
@@ -1175,7 +1175,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
11751175
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
11761176
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
11771177
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1178-
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
1178+
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
11791179
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
11801180

11811181
// CHECK: return

0 commit comments

Comments
 (0)