Skip to content

Commit be618c2

Browse files
freddan80facebook-github-bot
authored andcommitted
Port Arm backend op unittests: conv op + other ops (#2647)
Summary: Ported unit tests with combined ops. * Removed random seeds * Removed corresponing legacy test cases Pull Request resolved: #2647 Reviewed By: mergennachin Differential Revision: D55316080 Pulled By: digantdesai fbshipit-source-id: 5da3946fb191f47b0669c2157518bc93bd0c7554
1 parent bd6ceab commit be618c2

File tree

7 files changed

+272
-151
lines changed

7 files changed

+272
-151
lines changed

backends/arm/test/arm_tosa_reference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141
SUPPORTED_BI_TEST_LIST = [
4242
"simple_add_broadcast",
43-
"block_bottleneck_residual",
4443
]
4544

4645

backends/arm/test/ops/test_add.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020
logger.setLevel(logging.INFO)
2121

22-
torch.manual_seed(42)
23-
2422

2523
class TestSimpleAdd(unittest.TestCase):
2624
class Add(torch.nn.Module):

backends/arm/test/ops/test_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
logger = logging.getLogger(__name__)
1919
logger.setLevel(logging.INFO)
2020

21-
torch.manual_seed(42)
22-
2321

2422
class Conv2d(torch.nn.Module):
2523
"""
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
import unittest
9+
10+
from typing import Tuple
11+
12+
import torch
13+
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.test.test_models import TosaProfile
15+
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester
16+
17+
logger = logging.getLogger(__name__)
18+
logger.setLevel(logging.INFO)
19+
20+
"""
21+
This file contain unit tests where conv are combined with other ops.
22+
"""
23+
24+
25+
class ComboBlockBottleneckResidual(torch.nn.Module):
26+
# This is the essence of MobileNetV2. Ref: https://arxiv.org/abs/1801.04381
27+
edge_op_list = [
28+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
29+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
30+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
31+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
32+
]
33+
34+
def __init__(self):
35+
super().__init__()
36+
# (t, c, n, s) = (6, 96, 1, 1)
37+
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
38+
self.pointwise_conv2d = torch.nn.Conv2d(
39+
in_channels=64, out_channels=384, kernel_size=1, stride=1, groups=1
40+
) ## (1, 384, 81, 81)
41+
self.batch_norm2d_16 = torch.nn.BatchNorm2d(384, affine=False)
42+
self.relu6 = torch.nn.ReLU6()
43+
44+
# 2. 3x3 DepthwiseConv2d + ReLu6
45+
self.depthwise_conv2d = torch.nn.Conv2d(
46+
in_channels=384,
47+
out_channels=384,
48+
kernel_size=3,
49+
padding=1,
50+
stride=1,
51+
groups=384,
52+
) ## (1, 384, H, W)
53+
54+
# 3. Linear 1x1 Conv2d
55+
self.pointwise_conv2d_linear = torch.nn.Conv2d(
56+
in_channels=384, out_channels=64, kernel_size=1, stride=1, groups=1
57+
) ## (1, 64, 81, 81)
58+
59+
def get_inputs(self) -> Tuple[torch.Tensor]:
60+
return (torch.randn(1, 64, 81, 81),)
61+
62+
def forward(self, x):
63+
input = x
64+
# 1x1 CONV2d + ReLU6 (Pointwise)
65+
x = self.pointwise_conv2d(x)
66+
x = self.batch_norm2d_16(x)
67+
x = self.relu6(x)
68+
69+
# 3x3 DepthwiseConv2d + ReLu6
70+
x = self.depthwise_conv2d(x)
71+
x = self.batch_norm2d_16(x)
72+
x = self.relu6(x)
73+
74+
# Linear 1x1 Conv2d
75+
x = self.pointwise_conv2d_linear(x)
76+
77+
# Final Residual Connection
78+
x = x + input
79+
80+
return x
81+
82+
83+
class ComboConv2dMeandim(torch.nn.Module):
84+
edge_op_list = [
85+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
86+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
87+
]
88+
89+
def __init__(self):
90+
super().__init__()
91+
self.conv2d = torch.nn.Conv2d(
92+
in_channels=3, out_channels=10, kernel_size=5, stride=1, bias=False
93+
)
94+
# will be specialized to aten.mean.dim
95+
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
96+
97+
def get_inputs(self) -> Tuple[torch.Tensor]:
98+
return (torch.randn(1, 3, 128, 128),)
99+
100+
def forward(self, x):
101+
x = self.conv2d(x)
102+
return self.adaptive_avg_pool2d(x)
103+
104+
105+
class ComboConvBatchnormRelu(torch.nn.Module):
106+
edge_op_list = [
107+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
108+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
109+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
110+
]
111+
112+
def __init__(self):
113+
super().__init__()
114+
self.conv2d = torch.nn.Conv2d(
115+
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
116+
)
117+
self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
118+
self.relu6 = torch.nn.ReLU6()
119+
120+
def get_inputs(self) -> Tuple[torch.Tensor]:
121+
return (torch.randn(1, 3, 256, 256),)
122+
123+
def forward(self, x):
124+
x = self.conv2d(x)
125+
x = self.batch_norm2d(x)
126+
x = self.relu6(x)
127+
return x
128+
129+
130+
class TestConvCombos(unittest.TestCase):
131+
def _test_conv_combo_tosa_MI_pipeline(
132+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
133+
):
134+
tester = (
135+
ArmTester(
136+
module,
137+
inputs=test_data,
138+
profile=TosaProfile.MI,
139+
backend=ArmBackendSelector.TOSA,
140+
permute_memory_to_nhwc=True,
141+
)
142+
.export()
143+
.to_edge()
144+
.partition()
145+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
146+
.check_not(list(module.edge_op_list))
147+
.to_executorch()
148+
)
149+
if common.TOSA_REF_MODEL_INSTALLED:
150+
tester.run_method().compare_outputs()
151+
else:
152+
logger.warning(
153+
"TOSA ref model tool not installed, skip numerical correctness tests"
154+
)
155+
156+
def _test_conv_combo_tosa_BI_pipeline(
157+
self,
158+
module: torch.nn.Module,
159+
test_data: Tuple[torch.Tensor],
160+
atol: float = 1e-3,
161+
rtol: float = 1e-3,
162+
):
163+
tester = (
164+
ArmTester(
165+
module,
166+
inputs=test_data,
167+
profile=TosaProfile.BI,
168+
backend=ArmBackendSelector.TOSA,
169+
permute_memory_to_nhwc=True,
170+
)
171+
.quantize()
172+
.export()
173+
.to_edge()
174+
.partition()
175+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
176+
.check_not(list(module.edge_op_list))
177+
.to_executorch()
178+
)
179+
if common.TOSA_REF_MODEL_INSTALLED:
180+
tester.run_method().compare_outputs(atol=atol, rtol=rtol, qtol=1)
181+
else:
182+
logger.warning(
183+
"TOSA ref model tool not installed, skip numerical correctness tests"
184+
)
185+
186+
def _test_conv_combo_u55_BI_pipeline(
187+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
188+
):
189+
(
190+
ArmTester(
191+
module,
192+
inputs=test_data,
193+
profile=TosaProfile.BI,
194+
backend=ArmBackendSelector.ETHOS_U55,
195+
permute_memory_to_nhwc=True,
196+
)
197+
.quantize()
198+
.export()
199+
.to_edge()
200+
.partition()
201+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
202+
.check_not(list(module.edge_op_list))
203+
.to_executorch()
204+
)
205+
206+
####################
207+
## Conv + meandim ##
208+
####################
209+
def test_conv_meandim_tosa_MI(self):
210+
model = ComboConv2dMeandim()
211+
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
212+
213+
def test_conv_meandim_tosa_BI(self):
214+
model = ComboConv2dMeandim()
215+
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
216+
217+
@unittest.skipIf(
218+
not common.VELA_INSTALLED,
219+
"There is no point in running U55 tests if the Vela tool is not installed",
220+
)
221+
def test_conv_meandim_u55_BI(self):
222+
model = ComboConv2dMeandim()
223+
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())
224+
225+
##############################
226+
## Conv + batch norm + relu ##
227+
##############################
228+
def test_conv_batchnorm_relu_tosa_MI(self):
229+
model = ComboConvBatchnormRelu()
230+
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
231+
232+
# TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
233+
# testcase as well (and also not tested). For now, just increase the
234+
# tolerance, such that we don't skip the test entirely (i.e. we maintain
235+
# functionality).
236+
def test_conv_batchnorm_relu_tosa_BI(self):
237+
model = ComboConvBatchnormRelu()
238+
self._test_conv_combo_tosa_BI_pipeline(
239+
model, model.get_inputs(), atol=1.0, rtol=1.0
240+
)
241+
242+
@unittest.skipIf(
243+
not common.VELA_INSTALLED,
244+
"There is no point in running U55 tests if the Vela tool is not installed",
245+
)
246+
def test_conv_batchnorm_relu_u55_BI(self):
247+
model = ComboConvBatchnormRelu()
248+
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())
249+
250+
###############################
251+
## Block bottleneck residual ##
252+
###############################
253+
def test_block_bottleneck_residual_tosa_MI(self):
254+
model = ComboBlockBottleneckResidual()
255+
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
256+
257+
# TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
258+
# testcase as well. For now, just increase the tolerance, such that
259+
# we don't skip the test entirely (i.e. we maintain functionality).
260+
def test_block_bottleneck_residual_tosa_BI(self):
261+
model = ComboBlockBottleneckResidual()
262+
self._test_conv_combo_tosa_BI_pipeline(
263+
model, model.get_inputs(), atol=1.0, rtol=1.0
264+
)
265+
266+
@unittest.skipIf(
267+
not common.VELA_INSTALLED,
268+
"There is no point in running U55 tests if the Vela tool is not installed",
269+
)
270+
def test_block_bottleneck_residual_u55_BI(self):
271+
model = ComboBlockBottleneckResidual()
272+
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020
logger.setLevel(logging.INFO)
2121

22-
torch.manual_seed(42)
23-
2422
"""
2523
The configuration when
2624
groups == in_channels and

backends/arm/test/ops/test_linear.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020
logger.setLevel(logging.INFO)
2121

22-
torch.manual_seed(42)
2322

2423
test_data_suite_rank1 = [
2524
# (test_name, test_data, out_features)

0 commit comments

Comments
 (0)