
Commit eac5f44

Fix quantization for input to reference model
Add the zero point instead of subtracting it. The old code happened to work because the tests so far used all-ones inputs, which quantize to a zero point of -128; adding or subtracting -128 then yields the same np.int8 result, since int8 wraps around. The scaled values also need to be rounded and clipped to the int8 range.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: Ideaed6d072a4065573b38fb7476c7dbe8ba814fd
1 parent 0b6add8 commit eac5f44
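
Why the sign error went unnoticed: with an all-ones input the quantizer picks a zero point of -128, and at that value adding or subtracting the zero point produces the same wrapped int8. A minimal numpy sketch of that effect, using an illustrative scale value rather than the real quantization parameters:

import numpy as np

# Hypothetical parameters, for illustration only: an all-ones input quantized
# over an asymmetric int8 range ends up with a zero point of -128.
scale = np.float32(1.0 / 255.0)
zero_point = -128
x = np.ones(4, dtype=np.float32)

scaled = (x / scale).round().astype(np.int64)   # 255 for every element

wrong = (scaled - zero_point).astype(np.int8)   # 255 - (-128) = 383 -> wraps to 127
right = (scaled + zero_point).astype(np.int8)   # 255 + (-128) = 127
assert (wrong == right).all()                   # the sign error is invisible for this input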

File tree

2 files changed: +41 -21 lines changed


backends/arm/test/arm_tosa_reference.py

Lines changed: 10 additions & 2 deletions
@@ -135,9 +135,17 @@ def tosa_ref_dump_inputs(
     # Torch is doing Input[FP32]->Q[INT8]->DQ[FP32]->Operator[FP32]->Q[INT]->DQ[FP32]->[Output]FP32
     # Need to quantize the input to INT8 for TOSA comsumption
     if profile is TosaProfile.BI:
+        int8_max = np.iinfo(np.int8).max
+        int8_min = np.iinfo(np.int8).min
         data = (
-            (data / input_quantization_scales[name]) - input_quantization_zps[name]
-        ).astype(np.int8)
+            (
+                (data / np.float32(input_quantization_scales[name]))
+                + input_quantization_zps[name]
+            )
+            .round()
+            .clip(int8_min, int8_max)
+            .astype(np.int8)
+        )

     if save_on_disk:
         file_path = os.path.join(path, name + ".npy")
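
The corrected block implements standard affine quantization, q = clip(round(x / scale) + zero_point, -128, 127). A standalone sketch of the same computation, using made-up scale and zero-point values rather than the model's real quantization parameters:

import numpy as np

def quantize_to_int8(data: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # Affine quantization as in the fixed block: scale, add the zero point,
    # round, and clamp to the int8 range instead of letting values wrap.
    int8_min = np.iinfo(np.int8).min   # -128
    int8_max = np.iinfo(np.int8).max   # 127
    q = (data / np.float32(scale)) + zero_point
    return q.round().clip(int8_min, int8_max).astype(np.int8)

# Hypothetical parameters, for illustration only.
x = np.array([-1.5, 0.0, 0.5, 3.0], dtype=np.float32)
print(quantize_to_int8(x, scale=0.02, zero_point=10))
# -> [-65  10  35 127]; 3.0 would map to 160 and is clamped to 127 by the clip.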

backends/arm/test/ops/test_add.py

Lines changed: 31 additions & 19 deletions
@@ -9,7 +9,7 @@
 import shutil
 import unittest

-from typing import Optional, Tuple
+from typing import Tuple

 import torch
 from executorch.backends.arm.test.test_models import TosaProfile
@@ -30,13 +30,25 @@

 class TestSimpleAdd(unittest.TestCase):
     class Add(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(5),),
+            (3 * torch.ones(8),),
+            (10 * torch.randn(8),),
+        ]
+
         def __init__(self):
             super().__init__()

         def forward(self, x):
             return x + x

     class Add2(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
+            (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
+            (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+        ]
+
         def __init__(self):
             super().__init__()

@@ -88,7 +100,7 @@ def _test_add_tosa_BI_pipeline(
             .to_executorch()
         )
         if TOSA_REF_MODEL_INSTALLED:
-            tester.run_method().compare_outputs()
+            tester.run_method().compare_outputs(qtol=1)
         else:
             logger.warning(
                 "TOSA ref model tool not installed, skip numerical correctness tests"
@@ -114,42 +126,42 @@ def _test_add_u55_BI_pipeline(
             .to_executorch()
         )

-    def test_add_tosa_MI(self):
-        test_data = (torch.randn(4, 4, 4),)
+    @parameterized.expand(Add.test_parameters)
+    def test_add_tosa_MI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_tosa_MI_pipeline(self.Add(), test_data)

     # TODO: Will this type of parametrization be supported? pytest seem
     # have issue with it.
-    @parameterized.expand(
-        [
-            (torch.ones(5),),  # test_data
-            (3 * torch.ones(8),),
-        ]
-    )
-    def test_add_tosa_BI(self, test_data: Optional[Tuple[torch.Tensor]]):
+    @parameterized.expand(Add.test_parameters)
+    def test_add_tosa_BI(self, test_data: torch.Tensor):
         test_data = (test_data,)
         self._test_add_tosa_BI_pipeline(self.Add(), test_data)

     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    def test_add_u55_BI(self):
-        test_data = (3 * torch.ones(5),)
+    @parameterized.expand(Add.test_parameters)
+    def test_add_u55_BI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_u55_BI_pipeline(self.Add(), test_data)

-    def test_add2_tosa_MI(self):
-        test_data = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_tosa_MI_pipeline(self.Add2(), test_data)

-    def test_add2_tosa_BI(self):
-        test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_tosa_BI_pipeline(self.Add2(), test_data)

     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    def test_add2_u55_BI(self):
-        test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_u55_BI_pipeline(self.Add2(), test_data)
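
The test_add.py changes move the inputs into class-level test_parameters lists consumed by parameterized.expand, which turns each tuple into its own generated test case, and relax compare_outputs with qtol=1 to tolerate a one-step rounding difference in the quantized outputs. A self-contained sketch of the parameterization pattern (the module and assertion here are illustrative, not the ExecuTorch tester pipeline):

import unittest

import torch
from parameterized import parameterized


class Add(torch.nn.Module):
    # Each tuple becomes the argument list of one generated test case.
    test_parameters = [
        (torch.ones(5),),
        (3 * torch.ones(8),),
    ]

    def forward(self, x):
        return x + x


class TestAddSketch(unittest.TestCase):
    @parameterized.expand(Add.test_parameters)
    def test_add(self, test_data: torch.Tensor):
        out = Add()(test_data)
        torch.testing.assert_close(out, test_data + test_data)


if __name__ == "__main__":
    unittest.main()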
