Skip to content

Commit 675cdb3

Browse files
committed
Fix quantization for input to reference model
Add the zerpoint instead of subtracting. This worked since the tests so far used the ones as inputs which quantize to a zp of -128 which gives the same np.int8 result in both cases since the int8 wraps. Also needs to round and clip the scaled values to the int8 range. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Ideaed6d072a4065573b38fb7476c7dbe8ba814fd
1 parent f9cad4e commit 675cdb3

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

backends/arm/test/arm_tosa_reference.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,17 @@ def tosa_ref_dump_inputs(
139139
# Torch is doing Input[FP32]->Q[INT8]->DQ[FP32]->Operator[FP32]->Q[INT]->DQ[FP32]->[Output]FP32
140140
# Need to quantize the input to INT8 for TOSA comsumption
141141
if profile is TosaProfile.BI:
142+
int8_max = np.iinfo(np.int8).max
143+
int8_min = np.iinfo(np.int8).min
142144
data = (
143-
(data / input_quantization_scales[name]) - input_quantization_zps[name]
144-
).astype(np.int8)
145+
(
146+
(data / np.float32(input_quantization_scales[name]))
147+
+ input_quantization_zps[name]
148+
)
149+
.round()
150+
.clip(int8_min, int8_max)
151+
.astype(np.int8)
152+
)
145153

146154
if save_on_disk:
147155
file_path = os.path.join(path, name + ".npy")

backends/arm/test/ops/test_add.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import shutil
1010
import unittest
1111

12-
from typing import Optional, Tuple
12+
from typing import Tuple
1313

1414
import torch
1515
from executorch.backends.arm.test.test_models import TosaProfile
@@ -30,6 +30,12 @@
3030

3131
class TestSimpleAdd(unittest.TestCase):
3232
class Add(torch.nn.Module):
33+
test_parameters = [
34+
(torch.ones(5),),
35+
(3 * torch.ones(8),),
36+
(10 * torch.randn(8),),
37+
]
38+
3339
def __init__(self):
3440
super().__init__()
3541
self.permute_memory_to_nhwc = False
@@ -38,6 +44,12 @@ def forward(self, x):
3844
return x + x
3945

4046
class Add2(torch.nn.Module):
47+
test_parameters = [
48+
(torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
49+
(torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
50+
(torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
51+
]
52+
4153
def __init__(self):
4254
super().__init__()
4355
self.permute_memory_to_nhwc = False
@@ -118,40 +130,40 @@ def _test_add_u55_BI_pipeline(
118130
.to_executorch()
119131
)
120132

121-
def test_add_tosa_MI(self):
122-
test_data = (torch.randn(4, 4, 4),)
133+
@parameterized.expand(Add.test_parameters)
134+
def test_add_tosa_MI(self, test_data: torch.Tensor):
135+
test_data = (test_data,)
123136
self._test_add_tosa_MI_pipeline(self.Add(), test_data)
124137

125-
@parameterized.expand(
126-
[
127-
(torch.ones(5),), # test_data
128-
(3 * torch.ones(8),),
129-
]
130-
)
131-
def test_add_tosa_BI(self, test_data: Optional[Tuple[torch.Tensor]]):
138+
@parameterized.expand(Add.test_parameters)
139+
def test_add_tosa_BI(self, test_data: torch.Tensor):
132140
test_data = (test_data,)
133141
self._test_add_tosa_BI_pipeline(self.Add(), test_data)
134142

135143
@unittest.skipIf(
136144
not VELA_INSTALLED,
137145
"There is no point in running U55 tests if the Vela tool is not installed",
138146
)
139-
def test_add_u55_BI(self):
140-
test_data = (3 * torch.ones(5),)
147+
@parameterized.expand(Add.test_parameters)
148+
def test_add_u55_BI(self, test_data: torch.Tensor):
149+
test_data = (test_data,)
141150
self._test_add_u55_BI_pipeline(self.Add(), test_data)
142151

143-
def test_add2_tosa_MI(self):
144-
test_data = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
152+
@parameterized.expand(Add2.test_parameters)
153+
def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
154+
test_data = (operand1, operand2)
145155
self._test_add_tosa_MI_pipeline(self.Add2(), test_data)
146156

147-
def test_add2_tosa_BI(self):
148-
test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
157+
@parameterized.expand(Add2.test_parameters)
158+
def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
159+
test_data = (operand1, operand2)
149160
self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
150161

151162
@unittest.skipIf(
152163
not VELA_INSTALLED,
153164
"There is no point in running U55 tests if the Vela tool is not installed",
154165
)
155-
def test_add2_u55_BI(self):
156-
test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
166+
@parameterized.expand(Add2.test_parameters)
167+
def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
168+
test_data = (operand1, operand2)
157169
self._test_add_u55_BI_pipeline(self.Add2(), test_data)

0 commit comments

Comments
 (0)