Skip to content

Commit c9ff559

Browse files
committed
Fix upstream CI failures
Fix order of decorators to expand unittest first, and then parameterized input. Fix bug in add operator conversion to handle different scales correctly. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Ic228cf0215e8171392776739936a53c025802fd5
1 parent 82bd099 commit c9ff559

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

backends/arm/operators/op_add.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,11 @@ def define_node(
4747
input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
4848
input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)
4949

50-
max_scale_2x = 2.0 * max(input_A_scale.number, input_B_scale.number)
51-
inputA_rescale_scale = input_A_scale.number / max_scale_2x
52-
inputB_rescale_scale = input_B_scale.number / max_scale_2x
50+
# Scale the int8 quantized input to a common scale in the integer
51+
# domain.
52+
min_scale = min(input_A_scale.number, input_B_scale.number)
53+
inputA_rescale_scale = input_A_scale.number / min_scale
54+
inputB_rescale_scale = input_B_scale.number / min_scale
5355

5456
input_A_rescaled_to_int32 = build_rescale_to_int32(
5557
tosa_graph,
@@ -81,7 +83,7 @@ def define_node(
8183
# Output
8284
output_node = list(node.users)[0]
8385
_, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
84-
output_rescale_scale = max_scale_2x / (output_scale.number)
86+
output_rescale_scale = min_scale / output_scale.number
8587

8688
# Rescale Back to INT8
8789
build_rescale_from_int32(

backends/arm/test/ops/test_add.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class Add2(torch.nn.Module):
4747
(torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
4848
(torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
4949
(torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
50+
(10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
5051
]
5152

5253
def __init__(self):
@@ -136,11 +137,11 @@ def test_add_tosa_BI(self, test_data: torch.Tensor):
136137
test_data = (test_data,)
137138
self._test_add_tosa_BI_pipeline(self.Add(), test_data)
138139

140+
@parameterized.expand(Add.test_parameters)
139141
@unittest.skipIf(
140142
not VELA_INSTALLED,
141143
"There is no point in running U55 tests if the Vela tool is not installed",
142144
)
143-
@parameterized.expand(Add.test_parameters)
144145
def test_add_u55_BI(self, test_data: torch.Tensor):
145146
test_data = (test_data,)
146147
self._test_add_u55_BI_pipeline(self.Add(), test_data)
@@ -155,11 +156,11 @@ def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
155156
test_data = (operand1, operand2)
156157
self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
157158

159+
@parameterized.expand(Add2.test_parameters)
158160
@unittest.skipIf(
159161
not VELA_INSTALLED,
160162
"There is no point in running U55 tests if the Vela tool is not installed",
161163
)
162-
@parameterized.expand(Add2.test_parameters)
163164
def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
164165
test_data = (operand1, operand2)
165166
self._test_add_u55_BI_pipeline(self.Add2(), test_data)

0 commit comments

Comments
 (0)