Skip to content

Commit 65f3e18

Browse files
Add pooling and softmax unittests for Arm backend (#2645)
Summary: * Add SoftMax unittests * Add mean_dim unittests * Add AvgPool2d unittests * Enable linear u55 test Pull Request resolved: #2645 Reviewed By: mergennachin Differential Revision: D55316136 Pulled By: digantdesai fbshipit-source-id: 80d9a2e34a786742b47cb08b9b9c640f6b61964d
1 parent b8a28d4 commit 65f3e18

File tree

5 files changed

+487
-68
lines changed

5 files changed

+487
-68
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
import unittest
10+
11+
from typing import Tuple
12+
13+
import torch
14+
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test.test_models import TosaProfile
16+
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester
17+
from parameterized import parameterized
18+
19+
logger = logging.getLogger(__name__)
20+
logger.setLevel(logging.INFO)
21+
22+
test_data_suite = [
23+
# (test_name, test_data, [kernel_size, stride, padding])
24+
("zeros", torch.zeros(20, 16, 50, 32), [4, 2, 0]),
25+
("ones", torch.zeros(20, 16, 50, 32), [4, 2, 0]),
26+
("rand", torch.rand(20, 16, 50, 32), [4, 2, 0]),
27+
("randn", torch.randn(20, 16, 50, 32), [4, 2, 0]),
28+
]
29+
30+
31+
class TestAvgPool2d(unittest.TestCase):
32+
class AvgPool2d(torch.nn.Module):
33+
def __init__(
34+
self,
35+
kernel_size: int | Tuple[int, int],
36+
stride: int | Tuple[int, int],
37+
padding: int | Tuple[int, int],
38+
):
39+
super().__init__()
40+
self.avg_pool_2d = torch.nn.AvgPool2d(
41+
kernel_size=kernel_size, stride=stride, padding=padding
42+
)
43+
44+
def forward(self, x):
45+
return self.avg_pool_2d(x)
46+
47+
def _test_avgpool2d_tosa_MI_pipeline(
48+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
49+
):
50+
tester = (
51+
ArmTester(
52+
module,
53+
inputs=test_data,
54+
profile=TosaProfile.MI,
55+
backend=ArmBackendSelector.TOSA,
56+
permute_memory_to_nhwc=True,
57+
)
58+
.export()
59+
.check(["torch.ops.aten.avg_pool2d.default"])
60+
.check_not(["torch.ops.quantized_decomposed"])
61+
.to_edge()
62+
.partition()
63+
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
64+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
65+
.to_executorch()
66+
)
67+
if common.TOSA_REF_MODEL_INSTALLED:
68+
tester.run_method().compare_outputs()
69+
else:
70+
logger.warning(
71+
"TOSA ref model tool not installed, skip numerical correctness tests"
72+
)
73+
74+
def _test_avgpool2d_tosa_BI_pipeline(
75+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
76+
):
77+
tester = (
78+
ArmTester(
79+
module,
80+
inputs=test_data,
81+
profile=TosaProfile.BI,
82+
backend=ArmBackendSelector.TOSA,
83+
permute_memory_to_nhwc=True,
84+
)
85+
.quantize()
86+
.export()
87+
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
88+
.check(["torch.ops.quantized_decomposed"])
89+
.to_edge()
90+
.partition()
91+
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
92+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
93+
.to_executorch()
94+
)
95+
if common.TOSA_REF_MODEL_INSTALLED:
96+
tester.run_method().compare_outputs(qtol=1)
97+
else:
98+
logger.warning(
99+
"TOSA ref model tool not installed, skip numerical correctness tests"
100+
)
101+
102+
def _test_avgpool2d_tosa_u55_BI_pipeline(
103+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
104+
):
105+
(
106+
ArmTester(
107+
module,
108+
inputs=test_data,
109+
profile=TosaProfile.BI,
110+
backend=ArmBackendSelector.ETHOS_U55,
111+
permute_memory_to_nhwc=True,
112+
)
113+
.quantize()
114+
.export()
115+
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
116+
.check(["torch.ops.quantized_decomposed"])
117+
.to_edge()
118+
.partition()
119+
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
120+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
121+
.to_executorch()
122+
)
123+
124+
@parameterized.expand(test_data_suite)
125+
def test_avgpool2d_tosa_MI(
126+
self,
127+
test_name: str,
128+
test_data: torch.Tensor,
129+
model_params: int | Tuple[int, int],
130+
):
131+
self._test_avgpool2d_tosa_MI_pipeline(
132+
self.AvgPool2d(*model_params), (test_data,)
133+
)
134+
135+
# Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
136+
# TODO(MLETORCH-93)
137+
@parameterized.expand(test_data_suite)
138+
@unittest.expectedFailure
139+
def test_avgpool2d_tosa_BI(
140+
self,
141+
test_name: str,
142+
test_data: torch.Tensor,
143+
model_params: int | Tuple[int, int],
144+
):
145+
self._test_avgpool2d_tosa_BI_pipeline(
146+
self.AvgPool2d(*model_params), (test_data,)
147+
)
148+
149+
# Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
150+
# TODO(MLETORCH-93)
151+
@parameterized.expand(test_data_suite)
152+
@unittest.skipIf(
153+
not common.VELA_INSTALLED,
154+
"There is no point in running U55 tests if the Vela tool is not installed",
155+
)
156+
@unittest.expectedFailure
157+
def test_avgpool2d_tosa_u55_BI(
158+
self,
159+
test_name: str,
160+
test_data: torch.Tensor,
161+
model_params: int | Tuple[int, int],
162+
):
163+
self._test_avgpool2d_tosa_u55_BI_pipeline(
164+
self.AvgPool2d(*model_params), (test_data,)
165+
)

backends/arm/test/ops/test_linear.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,42 @@
2121

2222
torch.manual_seed(42)
2323

24-
test_data_suite = [
24+
test_data_suite_rank1 = [
2525
# (test_name, test_data, out_features)
2626
(
2727
"model_linear_rank1_zeros",
28-
torch.zeros(10, 10),
28+
torch.zeros(10),
2929
10,
3030
),
3131
(
3232
"model_linear_rank1_ones",
33-
torch.ones(10, 10),
33+
torch.ones(10),
3434
10,
3535
),
3636
(
3737
"model_linear_rank1_negative_ones",
38-
torch.ones(10, 10) * (-1),
38+
torch.ones(10) * (-1),
3939
10,
4040
),
4141
(
4242
"model_linear_rank1_rand",
43-
torch.rand(10, 10),
43+
torch.rand(10),
4444
10,
4545
),
4646
(
4747
"model_linear_rank1_negative_large_rand",
48-
torch.rand(10, 10) * (-100),
48+
torch.rand(10) * (-100),
4949
10,
5050
),
5151
(
5252
"model_linear_rank1_large_randn",
53-
torch.randn(10, 10) * 100,
53+
torch.randn(10) * 100,
5454
10,
5555
),
56+
]
57+
58+
test_data_suite_rank4 = [
59+
# (test_name, test_data, out_features)
5660
(
5761
"model_linear_rank4_zeros",
5862
torch.zeros(5, 10, 25, 20),
@@ -175,7 +179,7 @@ def _test_linear_tosa_u55_BI_pipeline(
175179
.to_executorch()
176180
)
177181

178-
@parameterized.expand(test_data_suite)
182+
@parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4)
179183
def test_linear_tosa_MI(
180184
self,
181185
test_name: str,
@@ -192,7 +196,7 @@ def test_linear_tosa_MI(
192196
test_data,
193197
)
194198

195-
@parameterized.expand(test_data_suite)
199+
@parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4)
196200
def test_linear_tosa_BI(
197201
self,
198202
test_name: str,
@@ -205,8 +209,11 @@ def test_linear_tosa_BI(
205209
self.Linear(in_features=in_features, out_features=out_features), test_data
206210
)
207211

208-
@parameterized.expand(test_data_suite)
209-
@unittest.skip("This does not work as of now")
212+
@parameterized.expand(test_data_suite_rank1)
213+
@unittest.skipIf(
214+
not common.VELA_INSTALLED,
215+
"There is no point in running U55 tests if the Vela tool is not installed",
216+
)
210217
def test_linear_tosa_u55_BI(
211218
self,
212219
test_name: str,

0 commit comments

Comments
 (0)