
Commit b0a400c

freddan80 authored and facebook-github-bot committed
Fix for TOSA BI clamp ops (#3092)
Summary: Min/max range values need to be in quantized form.

Pull Request resolved: #3092
Reviewed By: mergennachin
Differential Revision: D56476931
Pulled By: digantdesai
fbshipit-source-id: 80fe1e4981c048653f808ef1ad9339997eb853a6
1 parent 2f5cbd4 commit b0a400c
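Context for the fix: in the TOSA BI (base inference) profile, CLAMP operates on quantized integer values, so float min/max bounds must first be mapped through the input's quantization parameters. A minimal sketch of that affine mapping, with illustrative scale and zero-point values (not taken from the patch):

```python
def quantize_clamp_bound(x_fp: float, scale: float, zp: int, qmin: int, qmax: int) -> int:
    """Map a float clamp bound into the quantized integer domain."""
    q = round(x_fp / scale) + zp
    return max(qmin, min(q, qmax))  # keep the bound inside the dtype's range

# Illustrative int8 quantization: scale=0.1, zero point=-10.
# A ReLU6-style window of [0.0, 6.0] becomes [-10, 50] in integer space.
print(quantize_clamp_bound(0.0, 0.1, -10, -128, 127))  # -10
print(quantize_clamp_bound(6.0, 0.1, -10, -128, 127))  # 50
```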

File tree

8 files changed, +126 -44 lines changed


backends/arm/operators/op_addmm.py

Lines changed: 10 additions & 13 deletions
```diff
@@ -73,7 +73,7 @@ def define_node(
                 quant_node = input_node.all_input_nodes[0]
             else:
                 quant_node = input_node
-            input_zp = get_quant_node_args(quant_node)[1]
+            input_zp = get_quant_node_args(quant_node).zp
         attr.ConvAttribute(
             pad=pad_attr,
             stride=stride_attr,
@@ -111,24 +111,21 @@ def define_node(
             # rank > 2 linear layer
             if input_node.target == exir_ops.edge.aten.view_copy.default:
                 quant_node = input_node.all_input_nodes[0]
-                input_scale, _ = get_quant_node_args(quant_node)
+                input_scale = get_quant_node_args(quant_node).scale
                 consumer_node = list(node.users)[0]
                 consumer_consumer_node = list(consumer_node.users)[0]
-                (
-                    consumer_node_scale,
-                    consumer_node_node_zp,
-                ) = get_quant_node_args(consumer_consumer_node)
-
+                quant_args = get_quant_node_args(consumer_consumer_node)
+                consumer_node_scale = quant_args.scale
+                consumer_node_node_zp = quant_args.zp
             else:
-                input_scale, _ = get_quant_node_args(input_node)
+                input_scale = get_quant_node_args(input_node).scale
                 consumer_node = list(node.users)[0]
-                (
-                    consumer_node_scale,
-                    consumer_node_node_zp,
-                ) = get_quant_node_args(consumer_node)
+                quant_args = get_quant_node_args(consumer_node)
+                consumer_node_scale = quant_args.scale
+                consumer_node_node_zp = quant_args.zp

             weight_node_q_node = weight_node.all_input_nodes[0]
-            weight_scale, _ = get_quant_node_args(weight_node_q_node)
+            weight_scale = get_quant_node_args(weight_node_q_node).scale

             output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
             (
```
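For context on the rescale computed above: in a quantized matmul the int32 accumulator of (q_x - x_zp) @ (q_w - w_zp) represents the real product at scale input_scale * weight_scale, so it must be rescaled into the output's own quantization scale. A quick worked example with invented scales:

```python
# Invented scales, for illustration only.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.04

# The accumulator carries the real product at input_scale * weight_scale;
# dividing by the output scale gives the factor for the final rescale.
output_rescale_scale = (input_scale * weight_scale) / output_scale
print(output_rescale_scale)  # 0.0025
```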

backends/arm/operators/op_common.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -31,8 +31,8 @@ def build_avg_pool_2d_common(
     output_zp = 0

     if is_quant_node:
-        _, input_zp = get_quant_node_args(node.args[0])
-        _, output_zp = get_quant_node_args(list(node.users)[0])
+        input_zp = get_quant_node_args(node.args[0]).zp
+        output_zp = get_quant_node_args(list(node.users)[0]).zp

     attr = ts.TosaSerializerAttribute()
     attr.PoolAttribute(
```

backends/arm/operators/op_conv2d.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -80,7 +80,7 @@ def define_node(
         )

         input_zp = (
-            get_quant_node_args(node.all_input_nodes[0])[1] if is_quant_node else 0
+            get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
         )

         attr.ConvAttribute(
```

backends/arm/operators/op_hardtanh.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 Arm Limited and/or its affiliates.
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,6 +11,8 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+
+from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
 from serializer.tosa_serializer import TosaOp


@@ -30,12 +32,31 @@ def define_node(
         is_quant_node: bool,
     ) -> None:
         attr = ts.TosaSerializerAttribute()
+
+        if is_quant_node:
+            # Get quant parameters
+            scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
+            # Convert to quantized representation
+            clamp_min_qs = round((inputs[1].number / scale) + zp)
+            clamp_min_qs = max(clamp_min_qs, qmin)
+            clamp_max_qs = round((inputs[2].number / scale) + zp)
+            clamp_max_qs = min(clamp_max_qs, qmax)
+            # Set fp values to 0.0 since they are not used
+            clamp_min_fp = 0.0
+            clamp_max_fp = 0.0
+        else:
+            clamp_min_fp = inputs[1].number
+            clamp_max_fp = inputs[2].number
+            # Set qs values to 0 since they are not used
+            clamp_min_qs = 0
+            clamp_max_qs = 0
+
         attr.ClampAttribute(
             tosa_graph.builder,
-            int(inputs[1].number),
-            int(inputs[2].number),
-            inputs[1].number,
-            inputs[2].number,
+            clamp_min_qs,
+            clamp_max_qs,
+            clamp_min_fp,
+            clamp_max_fp,
         )

         tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr)
```
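This is the heart of the fix. The sketch below mirrors the new is_quant_node branch in isolation, using the QuantArgs tuple introduced later in this diff and a hypothetical int8 quantization; it also shows why the qmin/qmax capping matters:

```python
from typing import NamedTuple

class QuantArgs(NamedTuple):  # mirrors the class added in tosa_quant_utils.py
    scale: float
    zp: int
    qmin: int
    qmax: int

def clamp_bounds_qs(min_fp: float, max_fp: float, q: QuantArgs) -> tuple:
    """Quantize hardtanh bounds the same way the new branch does."""
    clamp_min_qs = max(round(min_fp / q.scale) + q.zp, q.qmin)
    clamp_max_qs = min(round(max_fp / q.scale) + q.zp, q.qmax)
    return clamp_min_qs, clamp_max_qs

# ReLU6 decomposes to hardtanh(0, 6); with a hypothetical scale of 0.05 and
# zero point 20, the window becomes [20, 127]: the upper bound is capped at
# qmax, which the previous int(inputs[...].number) cast never did.
print(clamp_bounds_qs(0.0, 6.0, QuantArgs(0.05, 20, -128, 127)))  # (20, 127)
```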

backends/arm/operators/op_placeholder.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -50,11 +50,13 @@ def process_placeholder(
         weight_node = weight_node_permuted.all_input_nodes[0]

         if input_node.target == exir_ops.edge.aten.view_copy.default:
-            input_node_scale, _ = get_quant_node_args(input_node.all_input_nodes[0])
+            input_node_scale = get_quant_node_args(
+                input_node.all_input_nodes[0]
+            ).scale
         else:
-            input_node_scale, _ = get_quant_node_args(input_node)
+            input_node_scale = get_quant_node_args(input_node).scale

-        weight_node_scale, _ = get_quant_node_args(weight_node)
+        weight_node_scale = get_quant_node_args(weight_node).scale

         bias_values_quantized = (
             (parameter_values / (input_node_scale * weight_node_scale))
@@ -81,8 +83,8 @@ def process_placeholder(
             bias_node,
         ) = consumer_node.all_input_nodes

-        input_node_scale, _ = get_quant_node_args(input_node)
-        weight_node_scale, _ = get_quant_node_args(weight_node)
+        input_node_scale = get_quant_node_args(input_node).scale
+        weight_node_scale = get_quant_node_args(weight_node).scale

         bias_scales = input_node_scale * weight_node_scale
         parameter_values_quantized = (
```
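The scale convention used here is the standard one for conv/linear bias: the bias is quantized with scale input_scale * weight_scale and zero point 0, which is why only the two scales are needed. A rough sketch with invented values:

```python
import torch

input_scale, weight_scale = 0.1, 0.02           # invented scales
bias_scale = input_scale * weight_scale         # bias zero point is 0
parameter_values = torch.tensor([0.5, -0.25])   # float bias values

bias_values_quantized = (parameter_values / bias_scale).round().to(torch.int32)
print(bias_values_quantized)  # tensor([ 250, -125], dtype=torch.int32)
```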

backends/arm/test/ops/test_conv_combos.py

Lines changed: 54 additions & 13 deletions
```diff
@@ -12,6 +12,7 @@
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -126,6 +127,32 @@ def forward(self, x):
         return x


+class ComboConvRelu6(torch.nn.Module):
+    edge_op_list = [
+        "executorch_exir_dialects_edge__ops_aten_convolution_default",
+        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
+    ]
+
+    test_data = [
+        (20 * torch.randn(1, 3, 256, 256),),
+        (5 * torch.randn(1, 3, 256, 256),),
+        (torch.randn(1, 3, 256, 256),),
+        (-5 * torch.randn(1, 3, 256, 256),),
+    ]
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+        )
+        self.relu6 = torch.nn.ReLU6()
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        x = self.relu6(x)
+        return x
+
+
 class TestConvCombos(unittest.TestCase):
     def _test_conv_combo_tosa_MI_pipeline(
         self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
@@ -222,15 +249,9 @@ def test_conv_batchnorm_relu_tosa_MI(self):
         model = ComboConvBatchnormRelu()
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

-    # TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
-    # testcase as well (and also not tested). For now, just increase the
-    # tolerance, such that we don't skip the test entirely (i.e. we maintain
-    # functionality).
     def test_conv_batchnorm_relu_tosa_BI(self):
         model = ComboConvBatchnormRelu()
-        self._test_conv_combo_tosa_BI_pipeline(
-            model, model.get_inputs(), atol=1.0, rtol=1.0
-        )
+        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

     @unittest.skipIf(
         not common.VELA_INSTALLED,
@@ -240,21 +261,41 @@ def test_conv_batchnorm_relu_u55_BI(self):
         model = ComboConvBatchnormRelu()
         self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

+    ##################
+    ## Conv + ReLU6 ##
+    ##################
+    @parameterized.expand(ComboConvRelu6.test_data)
+    def test_conv_relu6_tosa_MI(self, test_data: torch.Tensor):
+        model = ComboConvRelu6()
+        test_data = (test_data,)
+        self._test_conv_combo_tosa_MI_pipeline(model, test_data)
+
+    @parameterized.expand(ComboConvRelu6.test_data)
+    def test_conv_relu6_tosa_BI(self, test_data: torch.Tensor):
+        model = ComboConvRelu6()
+        test_data = (test_data,)
+        self._test_conv_combo_tosa_BI_pipeline(model, test_data)
+
+    @parameterized.expand(ComboConvRelu6.test_data)
+    @unittest.skipIf(
+        not common.VELA_INSTALLED,
+        "There is no point in running U55 tests if the Vela tool is not installed",
+    )
+    def test_conv_relu6_u55_BI(self, test_data: torch.Tensor):
+        model = ComboConvRelu6()
+        test_data = (test_data,)
+        self._test_conv_combo_u55_BI_pipeline(model, test_data)
+
     ###############################
     ## Block bottleneck residual ##
     ###############################
     def test_block_bottleneck_residual_tosa_MI(self):
         model = ComboBlockBottleneckResidual()
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

-    # TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
-    # testcase as well. For now, just increase the tolerance, such that
-    # we don't skip the test entirely (i.e. we maintain functionality).
     def test_block_bottleneck_residual_tosa_BI(self):
         model = ComboBlockBottleneckResidual()
-        self._test_conv_combo_tosa_BI_pipeline(
-            model, model.get_inputs(), atol=1.0, rtol=1.0
-        )
+        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

     @unittest.skipIf(
         not common.VELA_INSTALLED,
```
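The four test_data magnitudes (20x, 5x, 1x, -5x) are chosen so that inputs land below, inside, and above the ReLU6 window, which is exactly what exercises the clamp bounds this commit fixes. A small self-contained sketch of the same parameterized pattern (names are illustrative, not from the patch):

```python
import unittest

import torch
from parameterized import parameterized


class Relu6RangeTest(unittest.TestCase):
    # Each tuple becomes its own test case; the spread of magnitudes puts
    # samples below 0, inside [0, 6], and above 6.
    @parameterized.expand(
        [(20 * torch.randn(8),), (torch.randn(8),), (-5 * torch.randn(8),)]
    )
    def test_relu6_range(self, x: torch.Tensor):
        y = torch.nn.functional.relu6(x)
        self.assertTrue(((y >= 0) & (y <= 6)).all())
```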

backends/arm/tosa_quant_utils.py

Lines changed: 25 additions & 4 deletions
```diff
@@ -6,8 +6,10 @@
 # Utiliy functions for TOSA quantized lowerings

 import math
+from typing import NamedTuple

 import serializer.tosa_serializer as ts
+import torch.fx
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
 from serializer.tosa_serializer import TosaOp, TosaSerializerTensor
@@ -17,7 +19,14 @@
 dq_q_ops = [q_op, dq_op]


-def is_quant_node(node):
+class QuantArgs(NamedTuple):
+    scale: float
+    zp: int
+    qmin: int
+    qmax: int
+
+
+def is_quant_node(node: torch.fx.Node):
     consumer_node = list(node.users)[0]
     input = node.all_input_nodes[0]

@@ -41,10 +50,22 @@ def is_quant_arg(arg):
     return consumer_node.target == q_op


-def get_quant_node_args(node):
+def get_quant_node_args(node: torch.fx.Node):
+    """
+    Get the quantization parameters from a quant node.
+
+    Args:
+        node: The quant node.
+    Returns:
+        QuantArgs: scale, zp, qmin, qmax
+    """
     quant_args = [TosaArg(arg) for arg in node.args]
-    # Return the scale and zp
-    return quant_args[1].number, quant_args[2].number
+    return QuantArgs(
+        quant_args[1].number,
+        quant_args[2].number,
+        quant_args[3].number,
+        quant_args[4].number,
+    )


 # Check if scale32 mode is used for given output element type
```
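Since QuantArgs is a NamedTuple, existing call sites that unpack positionally keep working while updated ones read fields by name. A quick illustration with invented values (QuantArgs redeclared so the snippet is self-contained):

```python
from typing import NamedTuple

class QuantArgs(NamedTuple):  # same shape as the class added above
    scale: float
    zp: int
    qmin: int
    qmax: int

q = QuantArgs(scale=0.1, zp=-10, qmin=-128, qmax=127)
print(q.scale, q.zp)         # field access, as the updated call sites use
scale, zp, qmin, qmax = q    # positional unpacking still works: it is a tuple
print(scale, qmax)           # 0.1 127
```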

backends/xnnpack/test/tester/tester.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -595,7 +595,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
                 f"Output {i} does not match reference output.\n"
                 f"\tGiven atol: {atol}, rtol: {rtol}.\n"
                 f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
-                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}.\n"
+                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
                 f"\t-- Model vs. Reference --\n"
                 f"\t Numel: {model.numel()}, {ref.numel()}\n"
                 f"\tMedian: {model.median()}, {ref.median()}\n"
```
