Skip to content

Commit 001cc5f

Browse files
freddan80facebook-github-bot
authored andcommitted
Add convolution unit test for Arm backend (#2427)
Summary: * Added both depthwise and regular conv * Removed corresponding conv legacy unit tests * Fix zero point usage solves issues with random input * Removed add/add2 legacy unit tests as they are already implemented * Bump serialization lib submodule to avoid warning spam * Fixed rounding issue tosa_test_utils Pull Request resolved: #2427 Reviewed By: mcr229 Differential Revision: D55228457 Pulled By: digantdesai fbshipit-source-id: 2e7c9acfba582916c9d5c8e047f5aa57a37807f3
1 parent f6803b8 commit 001cc5f

File tree

12 files changed

+638
-494
lines changed

12 files changed

+638
-494
lines changed

backends/arm/operators/op_conv2d.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14-
from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
14+
from executorch.backends.arm.tosa_quant_utils import (
15+
build_rescale_conv_output,
16+
get_quant_node_args,
17+
)
1518
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs
1619

1720
from serializer.tosa_serializer import TosaOp
@@ -76,11 +79,15 @@ def define_node(
7679
dilation_attr[1],
7780
)
7881

82+
input_zp = (
83+
get_quant_node_args(node.all_input_nodes[0])[1] if is_quant_node else 0
84+
)
85+
7986
attr.ConvAttribute(
8087
pad=pad_attr,
8188
stride=stride_attr,
8289
dilation=dilation_attr,
83-
input_zp=0,
90+
input_zp=input_zp,
8491
weight_zp=0,
8592
local_bound=False,
8693
)
@@ -125,37 +132,31 @@ def define_node(
125132
build_reshape(
126133
tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
127134
)
128-
129-
tosa_graph.addOperator(
130-
TosaOp.Op().DEPTHWISE_CONV2D,
131-
[
132-
input.name,
133-
weight_reshaped.name,
134-
bias.name,
135-
],
136-
[conv2d_output_name],
137-
attr,
138-
)
135+
tosa_op = TosaOp.Op().DEPTHWISE_CONV2D
136+
weight_name = weight_reshaped.name
139137
else:
140138
"""Regular convolution case"""
141-
tosa_graph.addOperator(
142-
TosaOp.Op().CONV2D,
143-
[
144-
input.name,
145-
weight.name,
146-
bias.name,
147-
],
148-
[conv2d_output_name],
149-
attr,
150-
)
139+
tosa_op = TosaOp.Op().CONV2D
140+
weight_name = weight.name
141+
142+
tosa_graph.addOperator(
143+
tosa_op,
144+
[
145+
input.name,
146+
weight_name,
147+
bias.name,
148+
],
149+
[conv2d_output_name],
150+
attr,
151+
)
151152

152153
# For quantized convolution, rescale the output value back to the same
153154
# integer value domain of the next op. Otherwise return float32 output.
154155
if is_quant_node:
155156
# Get scale_factor from input, weight, and output.
156157
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
157158
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
158-
_, output_scale, _, _, _, _ = getNodeArgs(list(node.users)[0])
159+
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
159160
build_rescale_conv_output(
160161
tosa_graph,
161162
conv2d_res,
@@ -164,4 +165,5 @@ def define_node(
164165
input_scale,
165166
weight_scale,
166167
output_scale,
168+
output_zp,
167169
)

backends/arm/test/arm_tosa_reference.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,7 @@
3939
)
4040

4141
SUPPORTED_BI_TEST_LIST = [
42-
"simple_add",
4342
"simple_add_broadcast",
44-
"simple_conv2d_3x3_1x3x256x256_stride1",
45-
"simple_conv2d_1x1_1x2x128x128_stride1",
46-
"simple_conv2d_2x2_1x1x14x14_stride2",
47-
"simple_conv2d_5x5_3x2x128x128_stride1",
48-
"simple_conv2d_2x2_3x1x40x40_non_bias",
49-
"block_two_conv2d",
50-
"block_two_conv2d_non_bias",
5143
"block_bottleneck_residual",
5244
]
5345

backends/arm/test/common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import shutil
8+
9+
# TODO: fixme! These globs are a temporary workaround. Reasoning:
10+
# Running the jobs in _unittest.yml will not work since that environment doesn't
11+
# have the vela tool, nor the tosa_reference_model tool. Hence, we need a way to
12+
# run what we can in that env temporarily. Long term, vela and tosa_reference_model
13+
# should be installed in the CI env.
14+
TOSA_REF_MODEL_INSTALLED = shutil.which("tosa_reference_model")
15+
VELA_INSTALLED = shutil.which("vela")

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,16 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import logging
9-
import shutil
109
import unittest
1110

1211
import torch
1312
import torchvision.models as models
13+
from executorch.backends.arm.test import common
1414
from executorch.backends.arm.test.test_models import TosaProfile
1515
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester
1616
from executorch.backends.xnnpack.test.tester.tester import Quantize
1717
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
1818

19-
# TODO: fixme! These globs are a temporary workaround. Reasoning:
20-
# Running the jobs in _unittest.yml will not work since that environment don't
21-
# have the vela tool, nor the tosa_reference_model tool. Hence, we need a way to
22-
# run what we can in that env temporarily. Long term, vela and tosa_reference_model
23-
# should be installed in the CI env.
24-
TOSA_REF_MODEL_INSTALLED = shutil.which("tosa_reference_model")
2519

2620
logger = logging.getLogger(__name__)
2721
logger.setLevel(logging.INFO)
@@ -79,15 +73,17 @@ def test_mv2_tosa_BI(self):
7973
.partition()
8074
.to_executorch()
8175
)
82-
83-
if TOSA_REF_MODEL_INSTALLED:
76+
if common.TOSA_REF_MODEL_INSTALLED:
8477
tester.run_method().compare_outputs()
8578
else:
8679
logger.warning(
8780
"TOSA ref model tool not installed, skip numerical correctness tests"
8881
)
8982

90-
@unittest.skip("This test is not supported yet")
83+
@unittest.skipIf(
84+
not common.VELA_INSTALLED,
85+
"There is no point in running U55 tests if the Vela tool is not installed",
86+
)
9187
def test_mv2_u55_BI(self):
9288
(
9389
ArmTester(

backends/arm/test/ops/test_add.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,26 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

88
import logging
9-
import shutil
109
import unittest
1110

1211
from typing import Tuple
1312

1413
import torch
14+
from executorch.backends.arm.test import common
1515
from executorch.backends.arm.test.test_models import TosaProfile
1616
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester
1717
from parameterized import parameterized
1818

19-
# TODO: fixme! These globs are a temporary workaround. Reasoning:
20-
# Running the jobs in _unittest.yml will not work since that environment don't
21-
# have the vela tool, nor the tosa_reference_model tool. Hence, we need a way to
22-
# run what we can in that env temporarily. Long term, vela and tosa_reference_model
23-
# should be installed in the CI env.
24-
TOSA_REF_MODEL_INSTALLED = shutil.which("tosa_reference_model")
25-
VELA_INSTALLED = shutil.which("vela")
26-
2719
logger = logging.getLogger(__name__)
2820
logger.setLevel(logging.INFO)
2921

22+
torch.manual_seed(42)
23+
3024

3125
class TestSimpleAdd(unittest.TestCase):
3226
class Add(torch.nn.Module):
@@ -77,7 +71,7 @@ def _test_add_tosa_MI_pipeline(
7771
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
7872
.to_executorch()
7973
)
80-
if TOSA_REF_MODEL_INSTALLED:
74+
if common.TOSA_REF_MODEL_INSTALLED:
8175
tester.run_method().compare_outputs()
8276
else:
8377
logger.warning(
@@ -104,7 +98,8 @@ def _test_add_tosa_BI_pipeline(
10498
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
10599
.to_executorch()
106100
)
107-
if TOSA_REF_MODEL_INSTALLED:
101+
102+
if common.TOSA_REF_MODEL_INSTALLED:
108103
tester.run_method().compare_outputs(qtol=1)
109104
else:
110105
logger.warning(
@@ -143,7 +138,7 @@ def test_add_tosa_BI(self, test_data: torch.Tensor):
143138

144139
@parameterized.expand(Add.test_parameters)
145140
@unittest.skipIf(
146-
not VELA_INSTALLED,
141+
not common.VELA_INSTALLED,
147142
"There is no point in running U55 tests if the Vela tool is not installed",
148143
)
149144
def test_add_u55_BI(self, test_data: torch.Tensor):
@@ -162,7 +157,7 @@ def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
162157

163158
@parameterized.expand(Add2.test_parameters)
164159
@unittest.skipIf(
165-
not VELA_INSTALLED,
160+
not common.VELA_INSTALLED,
166161
"There is no point in running U55 tests if the Vela tool is not installed",
167162
)
168163
def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):

0 commit comments

Comments
 (0)