Skip to content

Commit 8348ec0

Browse files
authored
Arm backend: Remove N=1 constraint for MaxPool2d on U55 (#8765)
This is not needed anymore after and earlier Ethos-U compiler (Vela) version bump this is supported/handled by Vela. Signed-off-by: Erik Lundell <[email protected]>
1 parent 9c12c8f commit 8348ec0

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

backends/arm/operator_support/pool_2d_support.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def stride_check(strides: tuple[int, int]) -> bool:
2626

2727

2828
def dim_check(shape=torch.Size) -> bool:
29-
check = shape[0] == 1
30-
for dim in shape:
29+
check = True
30+
for dim in shape[1:]:
3131
check &= 1 <= dim <= 65536
3232
return check
3333

@@ -59,7 +59,7 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
5959
if not kernel_check(kernel):
6060
return False
6161

62-
return dim_check(shape) and stride_check(stride)
62+
return dim_check(shape) and shape[0] == 1 and stride_check(stride)
6363

6464

6565
@register_tosa_support_check

backends/arm/test/ops/test_max_pool.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -232,8 +232,24 @@ def test_maxpool2d_tosa_u85_BI_mult_batches(
232232
if conftest.is_option_enabled("corstone_fvp"):
233233
tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,))
234234

235+
@parameterized.expand(test_data_suite_mult_batches)
236+
@pytest.mark.corstone_fvp
237+
@conftest.expectedFailureOnFVP # TODO: MLETORCH-433
238+
def test_maxpool2d_tosa_u55_BI_mult_batches(
239+
self,
240+
test_name: str,
241+
test_data: torch.Tensor,
242+
model_params: int | Tuple[int, int],
243+
):
244+
tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
245+
self.MaxPool2d(*model_params),
246+
common.get_u55_compile_spec(),
247+
(test_data,),
248+
)
249+
if conftest.is_option_enabled("corstone_fvp"):
250+
tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,))
251+
235252
reject_data_suite = [
236-
(MaxPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)),
237253
(MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)),
238254
(MaxPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)),
239255
(MaxPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)),

0 commit comments

Comments
 (0)