
Commit f1b2bf8

Fix upstream CI failures

Fix the order of decorators so that unittest.skipIf is applied first and the parameterized input expansion last. Fix a bug in the add operator conversion so that inputs with different quantization scales are handled correctly.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: Ic228cf0215e8171392776739936a53c025802fd5

1 parent 675cdb3 commit f1b2bf8

2 files changed: +9, -6 lines
backends/arm/operators/op_add.py

Lines changed: 6 additions & 4 deletions
@@ -48,9 +48,11 @@ def define_node(
         input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
         input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)

-        max_scale_2x = 2.0 * max(input_A_scale.number, input_B_scale.number)
-        inputA_rescale_scale = input_A_scale.number / max_scale_2x
-        inputB_rescale_scale = input_B_scale.number / max_scale_2x
+        # Scale the int8 quantized input to a common scale in the integer
+        # domain.
+        min_scale = min(input_A_scale.number, input_B_scale.number)
+        inputA_rescale_scale = input_A_scale.number / min_scale
+        inputB_rescale_scale = input_B_scale.number / min_scale

         broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
         if permute_memory_to_nhwc:
@@ -88,7 +90,7 @@ def define_node(
         # Output
         output_node = list(node.users)[0]
         _, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
-        output_rescale_scale = max_scale_2x / (output_scale.number)
+        output_rescale_scale = min_scale / output_scale.number

         # Rescale Back to INT8
         build_rescale_from_int32(
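
A note on the op_add.py change above: both int8 inputs are rescaled to the smaller of the two input scales before the addition is done in the integer domain, and the int32 result is then rescaled to the output scale. Below is a minimal standalone sketch of that arithmetic (illustrative only; the helper names, scales, and values are assumptions, not the backend code):

import numpy as np

def quantize(x, scale, zp):
    # Simple int8 quantization helper for this example only.
    return np.clip(np.round(x / scale) + zp, -128, 127).astype(np.int32)

def dequantize(q, scale, zp):
    return (q.astype(np.float32) - zp) * scale

def simulated_quantized_add(qa, a_scale, a_zp, qb, b_scale, b_zp, out_scale, out_zp):
    # Rescale both inputs to a common (minimum) scale, as in the patch.
    min_scale = min(a_scale, b_scale)
    a_int = (qa - a_zp) * (a_scale / min_scale)
    b_int = (qb - b_zp) * (b_scale / min_scale)
    acc = np.round(a_int + b_int)  # integer-domain sum
    # Rescale back to int8 with output_rescale_scale = min_scale / out_scale.
    out = np.round(acc * (min_scale / out_scale)) + out_zp
    return np.clip(out, -128, 127).astype(np.int8)

a = np.array([0.5, -1.0])
b = np.array([200.0, -300.0])
a_scale, b_scale, out_scale = 0.01, 3.0, 3.0
qa = quantize(a, a_scale, 0)
qb = quantize(b, b_scale, 0)
q_out = simulated_quantized_add(qa, a_scale, 0, qb, b_scale, 0, out_scale, 0)
print(dequantize(q_out, out_scale, 0))  # close to a + b despite very different input scales

The added 10000 * torch.randn(...) input in test_add.py below exercises exactly this case of operands with very different quantization scales.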

backends/arm/test/ops/test_add.py

Lines changed: 3 additions & 2 deletions
@@ -48,6 +48,7 @@ class Add2(torch.nn.Module):
         (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
         (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
         (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+        (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
     ]

     def __init__(self):
@@ -140,11 +141,11 @@ def test_add_tosa_BI(self, test_data: torch.Tensor):
         test_data = (test_data,)
         self._test_add_tosa_BI_pipeline(self.Add(), test_data)

+    @parameterized.expand(Add.test_parameters)
     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    @parameterized.expand(Add.test_parameters)
     def test_add_u55_BI(self, test_data: torch.Tensor):
         test_data = (test_data,)
         self._test_add_u55_BI_pipeline(self.Add(), test_data)
@@ -159,11 +160,11 @@ def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
         test_data = (operand1, operand2)
         self._test_add_tosa_BI_pipeline(self.Add2(), test_data)

+    @parameterized.expand(Add2.test_parameters)
     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    @parameterized.expand(Add2.test_parameters)
     def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
         test_data = (operand1, operand2)
         self._test_add_u55_BI_pipeline(self.Add2(), test_data)
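
In the test changes above, @parameterized.expand is moved above @unittest.skipIf. Python applies stacked decorators bottom-up, so the skip condition is attached to the test function first and the parameter expansion happens last. A minimal self-contained sketch of the corrected ordering (the test case and the hard-coded VELA_INSTALLED flag are illustrative stand-ins, not the real suite):

import unittest
from parameterized import parameterized

VELA_INSTALLED = False  # stand-in for the real tool-availability check

class DecoratorOrderExample(unittest.TestCase):
    # Bottom-up application: skipIf marks the test first, then
    # parameterized.expand (outermost) generates one method per parameter set.
    @parameterized.expand([(1,), (2,)])
    @unittest.skipIf(
        not VELA_INSTALLED,
        "There is no point in running U55 tests if the Vela tool is not installed",
    )
    def test_add_like(self, value):
        self.assertGreater(value, 0)

if __name__ == "__main__":
    unittest.main()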
