Skip to content

Commit 82bd099

Browse files
committed
Fix quantization for input to reference model
Add the zerpoint instead of subtracting. This worked since the tests so far used the ones as inputs which quantize to a zp of -128 which gives the same np.int8 result in both cases since the int8 wraps. Also needs to round and clip the scaled values to the int8 range. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Ideaed6d072a4065573b38fb7476c7dbe8ba814fd
1 parent 45df800 commit 82bd099

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

backends/arm/test/arm_tosa_reference.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,17 @@ def tosa_ref_dump_inputs(
133133
# Torch is doing Input[FP32]->Q[INT8]->DQ[FP32]->Operator[FP32]->Q[INT]->DQ[FP32]->[Output]FP32
134134
# Need to quantize the input to INT8 for TOSA comsumption
135135
if profile is TosaProfile.BI:
136+
int8_max = np.iinfo(np.int8).max
137+
int8_min = np.iinfo(np.int8).min
136138
data = (
137-
(data / input_quantization_scales[name]) - input_quantization_zps[name]
138-
).astype(np.int8)
139+
(
140+
(data / np.float32(input_quantization_scales[name]))
141+
+ input_quantization_zps[name]
142+
)
143+
.round()
144+
.clip(int8_min, int8_max)
145+
.astype(np.int8)
146+
)
139147

140148
if save_on_disk:
141149
file_path = os.path.join(path, name + ".npy")

backends/arm/test/ops/test_add.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import shutil
1010
import unittest
1111

12-
from typing import Optional, Tuple
12+
from typing import Tuple
1313

1414
import torch
1515
from executorch.backends.arm.test.test_models import TosaProfile
@@ -30,13 +30,25 @@
3030

3131
class TestSimpleAdd(unittest.TestCase):
3232
class Add(torch.nn.Module):
33+
test_parameters = [
34+
(torch.ones(5),),
35+
(3 * torch.ones(8),),
36+
(10 * torch.randn(8),),
37+
]
38+
3339
def __init__(self):
3440
super().__init__()
3541

3642
def forward(self, x):
3743
return x + x
3844

3945
class Add2(torch.nn.Module):
46+
test_parameters = [
47+
(torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
48+
(torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
49+
(torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
50+
]
51+
4052
def __init__(self):
4153
super().__init__()
4254

@@ -114,40 +126,40 @@ def _test_add_u55_BI_pipeline(
114126
.to_executorch()
115127
)
116128

117-
def test_add_tosa_MI(self):
118-
test_data = (torch.randn(4, 4, 4),)
129+
@parameterized.expand(Add.test_parameters)
130+
def test_add_tosa_MI(self, test_data: torch.Tensor):
131+
test_data = (test_data,)
119132
self._test_add_tosa_MI_pipeline(self.Add(), test_data)
120133

121-
@parameterized.expand(
122-
[
123-
(torch.ones(5),), # test_data
124-
(3 * torch.ones(8),),
125-
]
126-
)
127-
def test_add_tosa_BI(self, test_data: Optional[Tuple[torch.Tensor]]):
134+
@parameterized.expand(Add.test_parameters)
135+
def test_add_tosa_BI(self, test_data: torch.Tensor):
128136
test_data = (test_data,)
129137
self._test_add_tosa_BI_pipeline(self.Add(), test_data)
130138

131139
@unittest.skipIf(
132140
not VELA_INSTALLED,
133141
"There is no point in running U55 tests if the Vela tool is not installed",
134142
)
135-
def test_add_u55_BI(self):
136-
test_data = (3 * torch.ones(5),)
143+
@parameterized.expand(Add.test_parameters)
144+
def test_add_u55_BI(self, test_data: torch.Tensor):
145+
test_data = (test_data,)
137146
self._test_add_u55_BI_pipeline(self.Add(), test_data)
138147

139-
def test_add2_tosa_MI(self):
140-
test_data = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
148+
@parameterized.expand(Add2.test_parameters)
149+
def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
150+
test_data = (operand1, operand2)
141151
self._test_add_tosa_MI_pipeline(self.Add2(), test_data)
142152

143-
def test_add2_tosa_BI(self):
144-
test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
153+
@parameterized.expand(Add2.test_parameters)
154+
def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
155+
test_data = (operand1, operand2)
145156
self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
146157

147158
@unittest.skipIf(
148159
not VELA_INSTALLED,
149160
"There is no point in running U55 tests if the Vela tool is not installed",
150161
)
151-
def test_add2_u55_BI(self):
152-
test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
162+
@parameterized.expand(Add2.test_parameters)
163+
def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
164+
test_data = (operand1, operand2)
153165
self._test_add_u55_BI_pipeline(self.Add2(), test_data)

0 commit comments

Comments
 (0)