Skip to content

Commit 713d8a1

Browse files
Add max_pool2d op to Arm backend (#6285)
* Add max_pool2d op to Arm backend. - Adds node visitor and unittests - Adds remove_getitem_op pass to convert (maxpool_get indices + getitem) -> maxpool2d op * Expected failures only for FVP
1 parent 545535b commit 713d8a1

File tree

7 files changed

+341
-0
lines changed

7 files changed

+341
-0
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
4444
UnsqueezeScalarPlaceholdersPass,
4545
)
46+
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
4647
from executorch.exir import ExportedProgram
4748
from executorch.exir.backend.compile_spec_schema import CompileSpec
4849
from executorch.exir.pass_manager import PassManager
@@ -58,6 +59,7 @@ def transform_to_backend_pipeline(
5859
):
5960
"""Apply passes before transforming program to backend"""
6061
self.add_pass(CastInt64ToInt32Pass(exported_program))
62+
self.add_pass(RemoveGetItemPass())
6163
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
6264
self.add_pass(SizeAdjustConv2DPass())
6365
self.add_pass(RemoveClonePass())

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5555
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
5656
exir_ops.edge.aten.native_layer_norm.default,
5757
exir_ops.edge.aten.avg_pool2d.default,
58+
exir_ops.edge.aten.max_pool2d_with_indices.default,
5859
exir_ops.edge.aten.sigmoid.default,
5960
exir_ops.edge.aten.mm.default,
6061
exir_ops.edge.aten.repeat.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
op_get_item,
2121
op_hardtanh,
2222
op_log,
23+
op_max_pool2d,
2324
op_mm,
2425
op_mul,
2526
op_permute,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import cast, List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_utils import get_quant_node_args
17+
18+
from serializer.tosa_serializer import TosaOp
19+
20+
21+
@register_node_visitor
class MaxPool2dVisitor(NodeVisitor):
    """Node visitor that emits a TOSA MAX_POOL2D operator for max_pool2d nodes."""

    target = "aten.max_pool2d.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize ``node`` as a MAX_POOL2D operator in ``tosa_graph``.

        Args:
            node: The FX node being lowered; its first input node and first
                user are queried for quantization parameters when
                ``is_quant_node`` is true.
            tosa_graph: Serializer the operator is added to.
            inputs: Node arguments in order: input tensor, kernel size,
                stride and, optionally, padding.
            output: The tensor written by the operator.
            is_quant_node: Whether the node is part of a quantized graph.
        """
        input_tensor = inputs[0]
        kernel_size = inputs[1].special
        stride = inputs[2].special

        if len(inputs) > 3:
            # PyTorch supplies one padding value per spatial dimension
            # (height, width) while the TOSA pool attribute expects
            # [top, bottom, left, right], so each value has to be repeated
            # for both sides of its own dimension.
            pad_arg = inputs[3].special
            pad_h, pad_w = pad_arg[0], pad_arg[-1]
            padding = [pad_h, pad_h, pad_w, pad_w]
        else:
            # The padding argument was omitted: no padding.
            padding = [0, 0, 0, 0]

        # Floating point defaults: accumulate in the input dtype and use
        # zero for both zero points.
        accumulator_type = input_tensor.dtype
        input_zp = 0
        output_zp = 0

        if is_quant_node:
            # Accumulator type always is int8 when input tensor is an integer type.
            accumulator_type = ts.DType.INT8
            input_zp = get_quant_node_args(
                cast(torch.fx.Node, node.all_input_nodes[0])
            ).zp
            output_zp = get_quant_node_args(list(node.users)[0]).zp

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(
            kernel=kernel_size,
            stride=stride,
            pad=padding,
            input_zp=input_zp,
            output_zp=output_zp,
            accum_dtype=accumulator_type,
        )

        tosa_graph.addOperator(
            TosaOp.Op().MAX_POOL2D,
            [input_tensor.name],
            [output.name],
            attr,
        )

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
147147
# TODO: remove?
148148
torch.ops.aten.adaptive_avg_pool2d.default,
149149
torch.ops.aten.avg_pool2d.default,
150+
torch.ops.aten.max_pool2d.default,
150151
torch.ops.aten.full.default,
151152
torch.ops.aten.flatten.using_ints,
152153
torch.ops.aten.dropout.default,

backends/arm/test/common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus):
9191

9292
# ==== End of Pytest hooks =====
9393

94+
# ==== Custom Pytest decorators =====
95+
96+
97+
def expectedFailureOnFVP(test_item):
    """Mark ``test_item`` as an expected failure when FVP testing is enabled.

    When the "corstone300" option is enabled the item is flagged through the
    ``__unittest_expecting_failure__`` attribute; otherwise it is returned
    untouched. The same object is returned in both cases, so this can be
    used as a decorator.
    """
    if not is_option_enabled("corstone300"):
        return test_item

    test_item.__unittest_expecting_failure__ = True
    return test_item
101+
102+
103+
# ==== End of Custom Pytest decorators =====
104+
94105

95106
def load_libquantized_ops_aot_lib():
96107
so_ext = {
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
import unittest
10+
11+
from typing import Tuple
12+
13+
import torch
14+
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
ArmQuantizer,
16+
get_symmetric_quantization_config,
17+
)
18+
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
21+
from executorch.backends.xnnpack.test.tester.tester import Quantize
22+
from executorch.exir.backend.backend_details import CompileSpec
23+
from parameterized import parameterized
24+
25+
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Each entry is (test_name, test_data, [kernel_size, stride, padding]); the
# trailing list is unpacked positionally when the test module is constructed.
test_data_suite = [
    ("zeros", torch.zeros((1, 1, 4, 8)), [2, 2, 1]),
    ("ones", torch.ones((1, 16, 50, 32)), [4, 2, 0]),
    ("rand", torch.rand((1, 16, 52, 16)), [4, 3, 0]),
]

# Same entry layout as above, for an input whose leading dimension is
# larger than one.
test_data_suite_mult_batches = [
    ("randn", torch.randn((5, 16, 50, 32)), [4, 2, 0]),
]
38+
39+
40+
class TestMaxPool2d(unittest.TestCase):
    """Tests MaxPool2d."""

    class MaxPool2d(torch.nn.Module):
        """Module under test: a single ``torch.nn.MaxPool2d`` layer."""

        def __init__(
            self,
            kernel_size: int | Tuple[int, int],
            stride: int | Tuple[int, int],
            padding: int | Tuple[int, int],
        ):
            super().__init__()
            self.max_pool_2d = torch.nn.MaxPool2d(
                kernel_size=kernel_size, stride=stride, padding=padding
            )

        def forward(self, x):
            return self.max_pool_2d(x)

    def _test_maxpool2d_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the non-quantized TOSA pipeline.

        The stage checks look for max_pool2d in the exported graph, for no
        max_pool2d edge op after partitioning, and for exactly one delegate
        call.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
            )
            .export()
            .check(["torch.ops.aten.max_pool2d.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_maxpool2d_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the quantized TOSA pipeline and compare outputs."""
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_maxpool2d_tosa_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Run ``module`` through the quantized pipeline for an Ethos-U target.

        Returns the serialized tester so the caller can optionally execute
        the result on an FVP.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        tester = (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )

        return tester

    def _run_on_fvp_if_enabled(
        self, tester, test_data: torch.Tensor, target_board: str
    ):
        """Execute ``tester`` on ``target_board`` when FVP testing is enabled.

        NOTE(review): the "corstone300" option gates every board, including
        "corstone-320" -- confirm that this is intended.
        """
        if common.is_option_enabled("corstone300"):
            tester.run_method_and_compare_outputs(
                qtol=1, inputs=(test_data,), target_board=target_board
            )

    @parameterized.expand(test_data_suite)
    def test_maxpool2d_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        self._test_maxpool2d_tosa_MI_pipeline(
            self.MaxPool2d(*model_params), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_maxpool2d_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        self._test_maxpool2d_tosa_BI_pipeline(
            self.MaxPool2d(*model_params), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_maxpool2d_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
            self.MaxPool2d(*model_params),
            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
            (test_data,),
        )
        self._run_on_fvp_if_enabled(tester, test_data, "corstone-300")

    @parameterized.expand(test_data_suite)
    def test_maxpool2d_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
            self.MaxPool2d(*model_params),
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            (test_data,),
        )
        self._run_on_fvp_if_enabled(tester, test_data, "corstone-320")

    @parameterized.expand(test_data_suite_mult_batches)
    def test_maxpool2d_tosa_MI_mult_batches(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        self._test_maxpool2d_tosa_MI_pipeline(
            self.MaxPool2d(*model_params), (test_data,)
        )

    @parameterized.expand(test_data_suite_mult_batches)
    def test_maxpool2d_tosa_BI_mult_batches(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        self._test_maxpool2d_tosa_BI_pipeline(
            self.MaxPool2d(*model_params), (test_data,)
        )

    @parameterized.expand(test_data_suite_mult_batches)
    @common.expectedFailureOnFVP  # TODO: MLETORCH-433
    def test_maxpool2d_tosa_u55_BI_mult_batches(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
            self.MaxPool2d(*model_params),
            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
            (test_data,),
        )
        self._run_on_fvp_if_enabled(tester, test_data, "corstone-300")

    @parameterized.expand(test_data_suite_mult_batches)
    @common.expectedFailureOnFVP  # TODO: MLETORCH-433
    def test_maxpool2d_tosa_u85_BI_mult_batches(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list[int],
    ):
        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
            self.MaxPool2d(*model_params),
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            (test_data,),
        )
        self._run_on_fvp_if_enabled(tester, test_data, "corstone-320")

0 commit comments

Comments
 (0)