
Commit d06ccd2

per authored and facebook-github-bot committed
Fix quantization for input to reference model (#2317)
Summary: Add the zero point instead of subtracting it. The old code worked until now because the tests used all-ones inputs, which quantize to a zero point of -128; with that zero point both signs give the same np.int8 result, since the int8 value wraps around. The scaled values also need to be rounded and clipped to the int8 range.

Signed-off-by: Per Åstrand <[email protected]>

Pull Request resolved: #2317

Reviewed By: mergennachin

Differential Revision: D55201623

Pulled By: digantdesai

fbshipit-source-id: 81a07186f3ebc4adb75af28cee109ab6ed4d0de8
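To make the masked bug concrete, here is a small numpy illustration (the scale below is invented for the example; only the zero point of -128, which the all-ones test inputs produce, matters):

import numpy as np

# Assumed quantization parameters: with all-ones inputs the tests ended up
# with a zero point of -128; the scale of 1/127 is just a convenient example.
scale = np.float32(1.0 / 127.0)
zp = -128
data = np.ones(4, dtype=np.float32)

q = data / scale                                         # 127.0 per element
old = (q - zp).astype(np.int8)                           # 255.0 wraps to -1 on typical platforms
new = (q + zp).round().clip(-128, 127).astype(np.int8)   # -1, computed correctly

# 255 and -1 are congruent modulo 256, so both expressions agree here.
# With inputs that do not quantize to zp == -128 (e.g. 10 * randn), the
# subtraction gives a different, wrong result.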
1 parent 9c7bb45 commit d06ccd2

File tree

3 files changed: +47, -24 lines

backends/arm/operators/op_add.py

Lines changed: 6 additions & 4 deletions
@@ -48,9 +48,11 @@ def define_node(
         input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
         input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)
 
-        max_scale_2x = 2.0 * max(input_A_scale.number, input_B_scale.number)
-        inputA_rescale_scale = input_A_scale.number / max_scale_2x
-        inputB_rescale_scale = input_B_scale.number / max_scale_2x
+        # Scale the int8 quantized input to a common scale in the integer
+        # domain.
+        min_scale = min(input_A_scale.number, input_B_scale.number)
+        inputA_rescale_scale = input_A_scale.number / min_scale
+        inputB_rescale_scale = input_B_scale.number / min_scale
 
         broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
         if permute_memory_to_nhwc:
@@ -88,7 +90,7 @@ def define_node(
         # Output
         output_node = list(node.users)[0]
         _, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
-        output_rescale_scale = max_scale_2x / (output_scale.number)
+        output_rescale_scale = min_scale / output_scale.number
 
         # Rescale Back to INT8
         build_rescale_from_int32(
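For context, here is a rough numpy sketch of the rescale-then-add scheme this lowering follows; the function and variable names are illustrative and do not correspond to the TOSA serializer API that build_rescale_from_int32 wraps:

import numpy as np

def add_int8_via_common_scale(qa, za, sa, qb, zb, sb, out_scale, out_zp):
    # Bring both int8 inputs to a common scale (the smaller of the two) in the
    # integer domain, accumulate in a wider type, then rescale the sum back to
    # the output's int8 quantization parameters.
    min_scale = min(sa, sb)
    a = np.round((qa.astype(np.int32) - za) * (sa / min_scale))
    b = np.round((qb.astype(np.int32) - zb) * (sb / min_scale))
    acc = a + b                                            # sum at the common scale
    out = np.round(acc * (min_scale / out_scale)) + out_zp
    return np.clip(out, -128, 127).astype(np.int8)

Using the smaller scale as the common scale keeps both rescale factors at or above 1, so the int8 operands are scaled up rather than compressed before the addition.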

backends/arm/test/arm_tosa_reference.py

Lines changed: 10 additions & 2 deletions
@@ -139,9 +139,17 @@ def tosa_ref_dump_inputs(
         # Torch is doing Input[FP32]->Q[INT8]->DQ[FP32]->Operator[FP32]->Q[INT]->DQ[FP32]->[Output]FP32
         # Need to quantize the input to INT8 for TOSA comsumption
         if profile is TosaProfile.BI:
+            int8_max = np.iinfo(np.int8).max
+            int8_min = np.iinfo(np.int8).min
             data = (
-                (data / input_quantization_scales[name]) - input_quantization_zps[name]
-            ).astype(np.int8)
+                (
+                    (data / np.float32(input_quantization_scales[name]))
+                    + input_quantization_zps[name]
+                )
+                .round()
+                .clip(int8_min, int8_max)
+                .astype(np.int8)
+            )
 
         if save_on_disk:
             file_path = os.path.join(path, name + ".npy")
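A quick numpy check of why the added .round() and .clip(...) matter once the tests feed wider-range inputs such as 10 * torch.randn(8) (the scale and zero point below are made up for the example):

import numpy as np

int8_min = np.iinfo(np.int8).min
int8_max = np.iinfo(np.int8).max

# Assumed quantization parameters for the example.
scale, zp = np.float32(0.1), 0
data = np.float32([12.96, -13.0, 1.06])

unclipped = ((data / scale) + zp).astype(np.int8)
clipped = ((data / scale) + zp).round().clip(int8_min, int8_max).astype(np.int8)

# unclipped: out-of-range values wrap on the int8 cast (129 -> -127, -129 -> 127
# on typical platforms) and in-range values truncate (10.6 -> 10).
# clipped:   values round first (10.6 -> 11) and out-of-range values saturate
# to 127 / -128 instead of wrapping.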

backends/arm/test/ops/test_add.py

Lines changed: 31 additions & 18 deletions
@@ -9,7 +9,7 @@
 import shutil
 import unittest
 
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 from executorch.backends.arm.test.test_models import TosaProfile
@@ -30,6 +30,12 @@
 
 class TestSimpleAdd(unittest.TestCase):
     class Add(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(5),),
+            (3 * torch.ones(8),),
+            (10 * torch.randn(8),),
+        ]
+
         def __init__(self):
             super().__init__()
             self.permute_memory_to_nhwc = False
@@ -38,6 +44,13 @@ def forward(self, x):
             return x + x
 
     class Add2(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
+            (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
+            (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+            (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+        ]
+
         def __init__(self):
             super().__init__()
             self.permute_memory_to_nhwc = False
@@ -118,40 +131,40 @@ def _test_add_u55_BI_pipeline(
             .to_executorch()
         )
 
-    def test_add_tosa_MI(self):
-        test_data = (torch.randn(4, 4, 4),)
+    @parameterized.expand(Add.test_parameters)
+    def test_add_tosa_MI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_tosa_MI_pipeline(self.Add(), test_data)
 
-    @parameterized.expand(
-        [
-            (torch.ones(5),),  # test_data
-            (3 * torch.ones(8),),
-        ]
-    )
-    def test_add_tosa_BI(self, test_data: Optional[Tuple[torch.Tensor]]):
+    @parameterized.expand(Add.test_parameters)
+    def test_add_tosa_BI(self, test_data: torch.Tensor):
         test_data = (test_data,)
         self._test_add_tosa_BI_pipeline(self.Add(), test_data)
 
+    @parameterized.expand(Add.test_parameters)
     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    def test_add_u55_BI(self):
-        test_data = (3 * torch.ones(5),)
+    def test_add_u55_BI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_u55_BI_pipeline(self.Add(), test_data)
 
-    def test_add2_tosa_MI(self):
-        test_data = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_tosa_MI_pipeline(self.Add2(), test_data)
 
-    def test_add2_tosa_BI(self):
-        test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
 
+    @parameterized.expand(Add2.test_parameters)
     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    def test_add2_u55_BI(self):
-        test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
+    def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_u55_BI_pipeline(self.Add2(), test_data)
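For reference, a minimal standalone example of the parameterization pattern the tests now rely on (the class and assertion below are made up, not part of the repository): parameterized.expand generates one test per tuple in the list and unpacks each tuple's elements into the test method's arguments, which is why the Add2 tests receive two operands.

import unittest

import torch
from parameterized import parameterized


class ParameterizedPatternExample(unittest.TestCase):
    # One generated test per tuple; tuple elements become positional arguments.
    test_parameters = [
        (torch.ones(5),),
        (10 * torch.randn(8),),
    ]

    @parameterized.expand(test_parameters)
    def test_double(self, test_data: torch.Tensor):
        test_data = (test_data,)   # re-wrap as the pipeline-style inputs tuple
        (x,) = test_data
        torch.testing.assert_close(x + x, 2 * x)


if __name__ == "__main__":
    unittest.main()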
