Skip to content

Commit 027aa70

Browse files
authored
[TOSA] Fix negate maxValue computation (#126295)
getInput1Zp() returns an unsigned value which means in case of negative zero point value the max intermediate value computation currently goes wrong. Use getInput1ZpAttr() instead which returns an APInt and allows easy sign extension to int64_t.
1 parent 95922d8 commit 027aa70

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,13 @@ static Value createLinalgBodyCalculationForElementwiseOp(
146146
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
147147

148148
if (isa<IntegerType>(elementTy)) {
149-
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1Zp();
150-
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZp();
149+
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
150+
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
151151

152-
const int64_t inZp = inputZpAttr ? *inputZpAttr : 0;
153-
const int64_t outZp = outputZpAttr ? *outputZpAttr : 0;
152+
const int64_t inZp =
153+
inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
154+
const int64_t outZp =
155+
outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
154156

155157
if (!inZp && !outZp) {
156158
auto constant = rewriter.create<arith::ConstantOp>(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,12 +911,25 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
911911
// CHECK: linalg.yield [[TRUNC]]
912912
%2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
913913

914+
// CHECK: linalg.generic
915+
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
916+
// CHECK: [[C_128:%.+]] = arith.constant -128
917+
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
918+
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
919+
// CHECK: [[MIN:%.+]] = arith.constant -128
920+
// CHECK: [[MAX:%.+]] = arith.constant 127
921+
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
922+
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
923+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
924+
// CHECK: linalg.yield [[TRUNC]]
925+
%3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
926+
914927
// CHECK: linalg.generic
915928
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
916929
// CHECK: [[ZERO:%.+]] = arith.constant 0
917930
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
918931
// CHECK: linalg.yield [[SUB]]
919-
%3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
932+
%4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
920933

921934
return
922935
}

0 commit comments

Comments
 (0)