Fix quantization for input to reference model #2317

Closed · wants to merge 2 commits into from

10 changes: 6 additions & 4 deletions backends/arm/operators/op_add.py
@@ -48,9 +48,11 @@ def define_node(
         input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
         input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)

-        max_scale_2x = 2.0 * max(input_A_scale.number, input_B_scale.number)
-        inputA_rescale_scale = input_A_scale.number / max_scale_2x
-        inputB_rescale_scale = input_B_scale.number / max_scale_2x
+        # Scale the int8 quantized input to a common scale in the integer
+        # domain.
+        min_scale = min(input_A_scale.number, input_B_scale.number)
+        inputA_rescale_scale = input_A_scale.number / min_scale
+        inputB_rescale_scale = input_B_scale.number / min_scale

         broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
         if permute_memory_to_nhwc:
@@ -88,7 +90,7 @@ def define_node(
         # Output
         output_node = list(node.users)[0]
         _, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
-        output_rescale_scale = max_scale_2x / (output_scale.number)
+        output_rescale_scale = min_scale / output_scale.number

         # Rescale Back to INT8
         build_rescale_from_int32(
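For reference, the arithmetic behind this change can be sketched in plain numpy. This is an illustrative model only: the helper name `quantized_add` is hypothetical, and the float multiplications stand in for TOSA RESCALE operations, which the real backend expresses as fixed-point multiplier/shift pairs. Using `min(scale_a, scale_b)` as the common scale keeps both input rescale multipliers at 1.0 or above, so widening the int8 inputs to int32 discards no information; the previous `2.0 * max(...)` denominator capped the multipliers at 0.5, throwing away low bits before the add.

```python
import numpy as np

def quantized_add(qa, sa, za, qb, sb, zb, s_out, z_out):
    # Bring both int8 operands to a common scale in the integer domain.
    min_scale = min(sa, sb)
    a32 = np.round((qa.astype(np.int32) - za) * (sa / min_scale)).astype(np.int32)
    b32 = np.round((qb.astype(np.int32) - zb) * (sb / min_scale)).astype(np.int32)
    acc = a32 + b32  # integer-domain addition on the common scale
    # Rescale the int32 accumulator to the output scale and zero point.
    out = np.round(acc * (min_scale / s_out)) + z_out
    return np.clip(out, -128, 127).astype(np.int8)
```

Dequantizing the result with `s_out * (q - z_out)` then tracks the float sum to within roughly one quantization step, which is the property the corrected op_add.py restores.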
12 changes: 10 additions & 2 deletions backends/arm/test/arm_tosa_reference.py
@@ -139,9 +139,17 @@ def tosa_ref_dump_inputs(
     # Torch is doing Input[FP32]->Q[INT8]->DQ[FP32]->Operator[FP32]->Q[INT8]->DQ[FP32]->Output[FP32]
     # Need to quantize the input to INT8 for TOSA consumption
     if profile is TosaProfile.BI:
+        int8_max = np.iinfo(np.int8).max
+        int8_min = np.iinfo(np.int8).min
         data = (
-            (data / input_quantization_scales[name]) - input_quantization_zps[name]
-        ).astype(np.int8)
+            (
+                (data / np.float32(input_quantization_scales[name]))
+                + input_quantization_zps[name]
+            )
+            .round()
+            .clip(int8_min, int8_max)
+            .astype(np.int8)
+        )

     if save_on_disk:
         file_path = os.path.join(path, name + ".npy")
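The removed line had two bugs: it subtracted the zero point during quantization (in the usual affine convention, quantization is q = round(x / scale) + zp; subtraction belongs to dequantization), and the bare `astype(np.int8)` truncated toward zero and wrapped around on overflow instead of rounding and saturating. A minimal sketch of the corrected convention, assuming per-tensor int8 quantization (function names here are illustrative, not the ExecuTorch API):

```python
import numpy as np

def quantize_int8(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # q = round(x / scale) + zp, saturated to the int8 range.
    info = np.iinfo(np.int8)
    q = np.round(x / np.float32(scale)) + zero_point
    return np.clip(q, info.min, info.max).astype(np.int8)

def dequantize_int8(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # x = scale * (q - zp), the inverse (up to rounding) of quantize_int8.
    return np.float32(scale) * (q.astype(np.float32) - zero_point)
```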
49 changes: 31 additions & 18 deletions backends/arm/test/ops/test_add.py
@@ -9,7 +9,7 @@
 import shutil
 import unittest

-from typing import Optional, Tuple
+from typing import Tuple

 import torch
 from executorch.backends.arm.test.test_models import TosaProfile
@@ -30,6 +30,12 @@

 class TestSimpleAdd(unittest.TestCase):
     class Add(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(5),),
+            (3 * torch.ones(8),),
+            (10 * torch.randn(8),),
+        ]
+
         def __init__(self):
             super().__init__()
             self.permute_memory_to_nhwc = False
@@ -38,6 +44,13 @@ def forward(self, x):
             return x + x

     class Add2(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)),
+            (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
+            (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+            (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+        ]
+
         def __init__(self):
             super().__init__()
             self.permute_memory_to_nhwc = False
@@ -118,40 +131,40 @@ def _test_add_u55_BI_pipeline(
             .to_executorch()
         )

-    def test_add_tosa_MI(self):
-        test_data = (torch.randn(4, 4, 4),)
+    @parameterized.expand(Add.test_parameters)
+    def test_add_tosa_MI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_tosa_MI_pipeline(self.Add(), test_data)

-    @parameterized.expand(
-        [
-            (torch.ones(5),),  # test_data
-            (3 * torch.ones(8),),
-        ]
-    )
-    def test_add_tosa_BI(self, test_data: Optional[Tuple[torch.Tensor]]):
+    @parameterized.expand(Add.test_parameters)
+    def test_add_tosa_BI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_tosa_BI_pipeline(self.Add(), test_data)

+    @parameterized.expand(Add.test_parameters)
     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    def test_add_u55_BI(self):
-        test_data = (3 * torch.ones(5),)
+    def test_add_u55_BI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
         self._test_add_u55_BI_pipeline(self.Add(), test_data)

-    def test_add2_tosa_MI(self):
-        test_data = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_tosa_MI_pipeline(self.Add2(), test_data)

-    def test_add2_tosa_BI(self):
-        test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
+    @parameterized.expand(Add2.test_parameters)
+    def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_tosa_BI_pipeline(self.Add2(), test_data)

+    @parameterized.expand(Add2.test_parameters)
     @unittest.skipIf(
         not VELA_INSTALLED,
         "There is no point in running U55 tests if the Vela tool is not installed",
     )
-    def test_add2_u55_BI(self):
-        test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1))
+    def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
         self._test_add_u55_BI_pipeline(self.Add2(), test_data)
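The test refactor above hinges on `parameterized.expand`, which turns each tuple in a `test_parameters` list into a separately generated test case whose elements arrive as positional arguments. A standalone sketch of the pattern, assuming the `parameterized` package is installed (the class and method names here are illustrative):

```python
import unittest

import torch
from parameterized import parameterized

class ExampleTest(unittest.TestCase):
    test_parameters = [
        (torch.ones(5),),
        (10 * torch.randn(8),),
    ]

    # One test method is generated per tuple, e.g. test_double_0, test_double_1.
    @parameterized.expand(test_parameters)
    def test_double(self, x: torch.Tensor):
        torch.testing.assert_close(x + x, 2 * x)
```

Hanging the data off the module classes (`Add.test_parameters`, `Add2.test_parameters`) lets the MI, BI, and U55 variants share one list per operator, so new inputs, such as the large-magnitude `10000 * torch.randn(...)` case that exercises the quantization fix, are picked up by all three pipelines at once.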