Skip to content

Commit b705860

Browse files
committed
Port Arm op unittests: conv op + other ops
Signed-off-by: Fredrik Knutsson <[email protected]> Change-Id: I0a2f83ff9eb245ebc6dc7714d8a3080eaf394a8e
1 parent 9e922d3 commit b705860

File tree

7 files changed

+268
-151
lines changed

7 files changed

+268
-151
lines changed

backends/arm/test/arm_tosa_reference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141
SUPPORTED_BI_TEST_LIST = [
4242
"simple_add_broadcast",
43-
"block_bottleneck_residual",
4443
]
4544

4645

backends/arm/test/ops/test_add.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020
logger.setLevel(logging.INFO)
2121

22-
torch.manual_seed(42)
23-
2422

2523
class TestSimpleAdd(unittest.TestCase):
2624
class Add(torch.nn.Module):

backends/arm/test/ops/test_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
logger = logging.getLogger(__name__)
1919
logger.setLevel(logging.INFO)
2020

21-
torch.manual_seed(42)
22-
2321

2422
class Conv2d(torch.nn.Module):
2523
"""
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from typing import Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.test_models import TosaProfile
14+
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester
15+
16+
"""
17+
This file contain unit tests where conv are combined with other ops.
18+
"""
19+
20+
21+
class ComboBlockBottleneckResidual(torch.nn.Module):
22+
# This is the essence of MobileNetV2. Ref: https://arxiv.org/abs/1801.04381
23+
edge_op_list = [
24+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
25+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
26+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
27+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
28+
]
29+
30+
def __init__(self):
31+
super().__init__()
32+
# (t, c, n, s) = (6, 96, 1, 1)
33+
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
34+
self.pointwise_conv2d = torch.nn.Conv2d(
35+
in_channels=64, out_channels=384, kernel_size=1, stride=1, groups=1
36+
) ## (1, 384, 81, 81)
37+
self.batch_norm2d_16 = torch.nn.BatchNorm2d(384, affine=False)
38+
self.relu6 = torch.nn.ReLU6()
39+
40+
# 2. 3x3 DepthwiseConv2d + ReLu6
41+
self.depthwise_conv2d = torch.nn.Conv2d(
42+
in_channels=384,
43+
out_channels=384,
44+
kernel_size=3,
45+
padding=1,
46+
stride=1,
47+
groups=384,
48+
) ## (1, 384, H, W)
49+
50+
# 3. Linear 1x1 Conv2d
51+
self.pointwise_conv2d_linear = torch.nn.Conv2d(
52+
in_channels=384, out_channels=64, kernel_size=1, stride=1, groups=1
53+
) ## (1, 64, 81, 81)
54+
55+
def get_inputs(self) -> Tuple[torch.Tensor]:
56+
return (torch.randn(1, 64, 81, 81),)
57+
58+
def forward(self, x):
59+
input = x
60+
# 1x1 CONV2d + ReLU6 (Pointwise)
61+
x = self.pointwise_conv2d(x)
62+
x = self.batch_norm2d_16(x)
63+
x = self.relu6(x)
64+
65+
# 3x3 DepthwiseConv2d + ReLu6
66+
x = self.depthwise_conv2d(x)
67+
x = self.batch_norm2d_16(x)
68+
x = self.relu6(x)
69+
70+
# Linear 1x1 Conv2d
71+
x = self.pointwise_conv2d_linear(x)
72+
73+
# Final Residual Connection
74+
x = x + input
75+
76+
return x
77+
78+
79+
class ComboConv2dMeandim(torch.nn.Module):
80+
edge_op_list = [
81+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
82+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
83+
]
84+
85+
def __init__(self):
86+
super().__init__()
87+
self.conv2d = torch.nn.Conv2d(
88+
in_channels=3, out_channels=10, kernel_size=5, stride=1, bias=False
89+
)
90+
# will be specialized to aten.mean.dim
91+
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
92+
93+
def get_inputs(self) -> Tuple[torch.Tensor]:
94+
return (torch.randn(1, 3, 128, 128),)
95+
96+
def forward(self, x):
97+
x = self.conv2d(x)
98+
return self.adaptive_avg_pool2d(x)
99+
100+
101+
class ComboConvBatchnormRelu(torch.nn.Module):
102+
edge_op_list = [
103+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
104+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
105+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
106+
]
107+
108+
def __init__(self):
109+
super().__init__()
110+
self.conv2d = torch.nn.Conv2d(
111+
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
112+
)
113+
self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
114+
self.relu6 = torch.nn.ReLU6()
115+
116+
def get_inputs(self) -> Tuple[torch.Tensor]:
117+
return (torch.randn(1, 3, 256, 256),)
118+
119+
def forward(self, x):
120+
x = self.conv2d(x)
121+
x = self.batch_norm2d(x)
122+
x = self.relu6(x)
123+
return x
124+
125+
126+
class TestConvCombos(unittest.TestCase):
127+
def _test_conv_combo_tosa_MI_pipeline(
128+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
129+
):
130+
tester = (
131+
ArmTester(
132+
module,
133+
inputs=test_data,
134+
profile=TosaProfile.MI,
135+
backend=ArmBackendSelector.TOSA,
136+
permute_memory_to_nhwc=True,
137+
)
138+
.export()
139+
.to_edge()
140+
.partition()
141+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
142+
.check_not(list(module.edge_op_list))
143+
.to_executorch()
144+
)
145+
if common.TOSA_REF_MODEL_INSTALLED:
146+
tester.run_method().compare_outputs()
147+
else:
148+
common.logger.warning(
149+
"TOSA ref model tool not installed, skip numerical correctness tests"
150+
)
151+
152+
def _test_conv_combo_tosa_BI_pipeline(
153+
self,
154+
module: torch.nn.Module,
155+
test_data: Tuple[torch.Tensor],
156+
atol: float = 1e-3,
157+
rtol: float = 1e-3,
158+
):
159+
tester = (
160+
ArmTester(
161+
module,
162+
inputs=test_data,
163+
profile=TosaProfile.BI,
164+
backend=ArmBackendSelector.TOSA,
165+
permute_memory_to_nhwc=True,
166+
)
167+
.quantize()
168+
.export()
169+
.to_edge()
170+
.partition()
171+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
172+
.check_not(list(module.edge_op_list))
173+
.to_executorch()
174+
)
175+
if common.TOSA_REF_MODEL_INSTALLED:
176+
tester.run_method().compare_outputs(atol=atol, rtol=rtol, qtol=1)
177+
else:
178+
common.logger.warning(
179+
"TOSA ref model tool not installed, skip numerical correctness tests"
180+
)
181+
182+
def _test_conv_combo_u55_BI_pipeline(
183+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
184+
):
185+
(
186+
ArmTester(
187+
module,
188+
inputs=test_data,
189+
profile=TosaProfile.BI,
190+
backend=ArmBackendSelector.ETHOS_U55,
191+
permute_memory_to_nhwc=True,
192+
)
193+
.quantize()
194+
.export()
195+
.to_edge()
196+
.partition()
197+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
198+
.check_not(list(module.edge_op_list))
199+
.to_executorch()
200+
)
201+
202+
####################
203+
## Conv + meandim ##
204+
####################
205+
def test_conv_meandim_tosa_MI(self):
206+
model = ComboConv2dMeandim()
207+
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
208+
209+
def test_conv_meandim_tosa_BI(self):
210+
model = ComboConv2dMeandim()
211+
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
212+
213+
@unittest.skipIf(
214+
not common.VELA_INSTALLED,
215+
"There is no point in running U55 tests if the Vela tool is not installed",
216+
)
217+
def test_conv_meandim_u55_BI(self):
218+
model = ComboConv2dMeandim()
219+
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())
220+
221+
##############################
222+
## Conv + batch norm + relu ##
223+
##############################
224+
def test_conv_batchnorm_relu_tosa_MI(self):
225+
model = ComboConvBatchnormRelu()
226+
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
227+
228+
# TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
229+
# testcase as well (and also not tested). For now, just increase the
230+
# tolerance, such that we don't skip the test entirely (i.e. we maintain
231+
# functionality).
232+
def test_conv_batchnorm_relu_tosa_BI(self):
233+
model = ComboConvBatchnormRelu()
234+
self._test_conv_combo_tosa_BI_pipeline(
235+
model, model.get_inputs(), atol=1.0, rtol=1.0
236+
)
237+
238+
@unittest.skipIf(
239+
not common.VELA_INSTALLED,
240+
"There is no point in running U55 tests if the Vela tool is not installed",
241+
)
242+
def test_conv_batchnorm_relu_u55_BI(self):
243+
model = ComboConvBatchnormRelu()
244+
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())
245+
246+
###############################
247+
## Block bottleneck residual ##
248+
###############################
249+
def test_block_bottleneck_residual_tosa_MI(self):
250+
model = ComboBlockBottleneckResidual()
251+
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
252+
253+
# TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
254+
# testcase as well. For now, just increase the tolerance, such that
255+
# we don't skip the test entirely (i.e. we maintain functionality).
256+
def test_block_bottleneck_residual_tosa_BI(self):
257+
model = ComboBlockBottleneckResidual()
258+
self._test_conv_combo_tosa_BI_pipeline(
259+
model, model.get_inputs(), atol=1.0, rtol=1.0
260+
)
261+
262+
@unittest.skipIf(
263+
not common.VELA_INSTALLED,
264+
"There is no point in running U55 tests if the Vela tool is not installed",
265+
)
266+
def test_block_bottleneck_residual_u55_BI(self):
267+
model = ComboBlockBottleneckResidual()
268+
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020
logger.setLevel(logging.INFO)
2121

22-
torch.manual_seed(42)
23-
2422
"""
2523
The configuration when
2624
groups == in_channels and

backends/arm/test/ops/test_linear.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020
logger.setLevel(logging.INFO)
2121

22-
torch.manual_seed(42)
2322

2423
test_data_suite = [
2524
# (test_name, test_data, out_features)

0 commit comments

Comments
 (0)