Skip to content

Commit acdca0c

Browse files
committed
Arm backend: Improve pooling args handling
- Add support for ceil_mode=True for avgpool and maxpool - Add support for count_include_pad==True for avgpool - Add support for divisor_override for avgpool - Fix padding check in pool_2d_support Signed-off-by: Adrian Lundell <[email protected]> Change-Id: I4e4ec8ebaf174279893b640f87b691ea03cb668d
1 parent 994752e commit acdca0c

File tree

10 files changed

+283
-79
lines changed

10 files changed

+283
-79
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2121
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2222
from .convert_to_clamp import ConvertToClampPass # noqa
23+
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2324
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2425
from .decompose_div_pass import DecomposeDivPass # noqa
2526
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ConvertSplitToSlicePass,
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
26+
DecomposeAvgPool2d,
2627
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeEmbeddingPass,
@@ -63,7 +64,6 @@
6364
UnsqueezeBeforeRepeatPass,
6465
UnsqueezeScalarPlaceholdersPass,
6566
)
66-
6767
from executorch.backends.arm.tosa_specification import (
6868
TosaLoweringContext,
6969
TosaSpecification,
@@ -115,6 +115,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115115
if self.tosa_spec.is_U55_subset:
116116
self.add_pass(BroadcastArgsPass())
117117
self.add_pass(DecomposeLinearPass())
118+
self.add_pass(DecomposeAvgPool2d())
118119
self.add_pass(ComputeConstantOpsAOT(exported_program))
119120

120121
self.add_pass(RemoveClonePass())
@@ -172,6 +173,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
172173
self.add_pass(RetraceFoldedDtypesPass())
173174
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
174175
self.add_pass(MatchArgRanksPass(exported_program))
176+
self.add_pass(DecomposeAvgPool2d())
175177
self.add_pass(ComputeConstantOpsAOT(exported_program))
176178

177179
self.add_pass(RemoveClonePass())
@@ -232,6 +234,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
232234
self.add_pass(DecomposeLinearVectorNormPass())
233235
self.add_pass(DecomposeSqrtPass())
234236
self.add_pass(DecomposeSiluPass())
237+
self.add_pass(DecomposeAvgPool2d())
235238

236239
if self.tosa_spec.is_U55_subset:
237240
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import torch
8+
from executorch.backends.arm.operators.operator_validation_utils import (
9+
adjust_pooling_pad_if_needed,
10+
)
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
14+
edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,)
15+
aten_div_ops = (torch.ops.aten.avg_pool2d.default,)
16+
17+
18+
def get_decomposition(op) -> tuple:
19+
if op in edge_div_ops:
20+
return (
21+
exir_ops.edge.aten.full.default,
22+
exir_ops.edge.aten.cat.default,
23+
exir_ops.edge.aten.avg_pool2d.default,
24+
exir_ops.edge.aten.mul.Tensor,
25+
)
26+
if op in aten_div_ops:
27+
return (
28+
torch.ops.aten.full.default,
29+
torch.ops.aten.cat.default,
30+
torch.ops.aten.avg_pool2d.default,
31+
torch.ops.aten.mul.Tensor,
32+
)
33+
raise RuntimeError(f"Can't get div decomposition for op {op}")
34+
35+
36+
class DecomposeAvgPool2d(ExportPass):
37+
""" """
38+
39+
def call_operator(self, op, args, kwargs, meta):
40+
if op not in (edge_div_ops + aten_div_ops):
41+
return super().call_operator(op, args, kwargs, meta)
42+
43+
full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)
44+
45+
x = args[0]
46+
kernel_h, kernel_w = args[1]
47+
kernel_size = kernel_h * kernel_w
48+
stride_h, stride_w = args[2]
49+
pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
50+
ceil_mode = args[4] if len(args) > 4 else False
51+
count_include_pad = args[5] if len(args) > 5 else True
52+
divisor_override = args[6] if len(args) > 6 else None
53+
54+
n, c, h, w = x.data.shape
55+
post_pad_w, post_pad_h = (0, 0)
56+
57+
# Count_include_pad == False means that we use a different divisor for edge elements
58+
# When divisor_override is set, this will be overriden anyways.
59+
# It is easier to replace a constant divisor, so set count_include_pad == True
60+
if divisor_override is not None:
61+
count_include_pad = True
62+
63+
# Add width padding manually if count_include_pad
64+
if count_include_pad and pad_w > 0:
65+
pre_pad_shape = [n, c, h, pad_w]
66+
pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)
67+
68+
if ceil_mode and divisor_override is None:
69+
post_pad_w = pad_w
70+
else:
71+
post_pad_w = adjust_pooling_pad_if_needed(
72+
w, kernel_w, stride_w, pad_w, ceil_mode
73+
)
74+
75+
if post_pad_w > 0:
76+
post_pad_shape = [n, c, h, post_pad_w]
77+
post_pad = super().call_operator(
78+
full_op, (post_pad_shape, 0.0), kwargs, meta
79+
)
80+
cat_nodes = [pre_pad, x, post_pad]
81+
else:
82+
cat_nodes = [pre_pad, x]
83+
84+
x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta)
85+
new_pad_w = 0
86+
87+
# Add height padding manually if count_include_pad
88+
if count_include_pad and pad_h > 0:
89+
pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
90+
pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)
91+
92+
if ceil_mode and divisor_override is None:
93+
post_pad_h = pad_h
94+
else:
95+
post_pad_h = adjust_pooling_pad_if_needed(
96+
h, kernel_h, stride_h, pad_h, ceil_mode
97+
)
98+
99+
if post_pad_h > 0:
100+
post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
101+
post_pad = super().call_operator(
102+
full_op, (post_pad_shape, 0.0), kwargs, meta
103+
)
104+
cat_nodes = [pre_pad, x, post_pad]
105+
else:
106+
cat_nodes = [pre_pad, x]
107+
108+
x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
109+
new_pad_h = 0
110+
111+
avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
112+
x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)
113+
114+
# Multiply by factor (kernel_size / divisor_override) if divisor_override
115+
if divisor_override is not None and divisor_override != kernel_size:
116+
override_multiplier = super().call_operator(
117+
full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta
118+
)
119+
x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta)
120+
121+
return x

backends/arm/_passes/decompose_maxpool2d_with_dilation.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def call_operator(self, op, args, kwargs, meta):
3636
stride = args[2]
3737
padding = args[3] if len(args) >= 4 else 0
3838
dilation = args[4] if len(args) >= 5 else 1
39+
ceil_mode = args[5] if len(args) == 6 else False
3940

4041
# Normalize attributes
4142
pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
@@ -45,12 +46,9 @@ def call_operator(self, op, args, kwargs, meta):
4546
)
4647
s_h, s_w = (stride, stride) if isinstance(stride, int) else stride
4748

48-
# If no dilation: call EXIR edge op with only supported args (x, kernel, stride[, padding])
49+
# If no dilation: call EXIR edge op
4950
if d_h == 1 and d_w == 1:
50-
minimal_args = [x, kernel_size, stride]
51-
# only include padding if non-zero
52-
if (pad_h, pad_w) != (0, 0):
53-
minimal_args.append((pad_h, pad_w))
51+
minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode]
5452
return super().call_operator(op, tuple(minimal_args), {}, meta)
5553

5654
# Compute padded and packed dimensions for dilation > 1
@@ -102,7 +100,7 @@ def call_operator(self, op, args, kwargs, meta):
102100
if is_with_indices
103101
else exir_ops.edge.aten.max_pool2d.default
104102
)
105-
pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0))
103+
pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode)
106104
pool_out = super().call_operator(
107105
pool_edge_op,
108106
pool_args,

backends/arm/operator_support/pool_2d_support.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
register_tosa_support_check,
1313
SupportedTOSAOperatorCheck,
1414
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
adjust_pooling_pad_if_needed,
17+
)
1518
from executorch.backends.arm.tosa_specification import TosaSpecification
1619
from executorch.exir.dialects._ops import ops as exir_ops
1720

@@ -56,25 +59,42 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
5659
input_arg = get_first_fake_tensor(input_arg)
5760
shape = input_arg.data.shape # type: ignore[union-attr]
5861

62+
# Calculate padding used in the final TOSA operator
5963
kernel = cast(tuple[int, int], node.args[1])
6064
stride = cast(tuple[int, int], node.args[2])
61-
if len(node.args) > 3:
62-
padding = cast(tuple[int, int], node.args[3])
63-
# Padding case
64-
if not all(1 <= k <= 8 for k in kernel) and not all(
65-
v == 0 for v in padding
66-
):
67-
self.reporter.report_reject(
68-
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
69-
)
70-
return False
65+
padding = cast(tuple[int, int], node.args[3]) if len(node.args) > 3 else (0, 0)
66+
ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False
67+
count_include_pad = cast(bool, node.args[5]) if len(node.args) > 5 else True
68+
divisor_override = cast(int, node.args[6]) if len(node.args) > 6 else None
69+
70+
# If count_include_pad is True or divior_override is given, padding is applied
71+
# by concating zero-elements rather than setting it in the avg_pool op.
72+
if count_include_pad or divisor_override is not None:
73+
tosa_padding = (0, 0, 0, 0)
74+
# Otherwise, calculate the padding as done in the node visitor
7175
else:
72-
if not kernel_check(kernel):
73-
self.reporter.report_reject(
74-
node,
75-
f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
76-
)
77-
return False
76+
post_pad_h = adjust_pooling_pad_if_needed(
77+
shape[2], kernel[0], stride[0], padding[0], ceil_mode
78+
)
79+
post_pad_w = adjust_pooling_pad_if_needed(
80+
shape[3], kernel[1], stride[1], padding[1], ceil_mode
81+
)
82+
tosa_padding = (padding[0], post_pad_h, padding[1], post_pad_w)
83+
84+
if not all(1 <= k <= 8 for k in kernel) and not all(
85+
v == 0 for v in tosa_padding
86+
):
87+
self.reporter.report_reject(
88+
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
89+
)
90+
return False
91+
92+
if not kernel_check(kernel):
93+
self.reporter.report_reject(
94+
node,
95+
f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
96+
)
97+
return False
7898

7999
if not dim_check(shape):
80100
self.reporter.report_reject(

backends/arm/operators/op_avg_pool2d.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ def _build_generic_avgpool2d(
5454
kernel_size_list = inputs[1].special
5555
stride_size_list = inputs[2].special
5656

57+
if len(inputs) > 4:
58+
ceil_mode = bool(inputs[4].number)
59+
else:
60+
ceil_mode = False
61+
5762
try:
5863
pad_size_list = inputs[3].special
5964
pad_size_list = [
@@ -71,12 +76,14 @@ def _build_generic_avgpool2d(
7176
kernel_size_list[0],
7277
stride_size_list[0],
7378
pad_size_list[1],
79+
ceil_mode,
7480
)
7581
pad_size_list[3] = adjust_pooling_pad_if_needed(
7682
input_tensor.shape[3],
7783
kernel_size_list[1],
7884
stride_size_list[1],
7985
pad_size_list[3],
86+
ceil_mode,
8087
)
8188

8289
attr = ts.TosaSerializerAttribute()
@@ -105,7 +112,7 @@ def define_node(
105112
) -> None:
106113
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107114

108-
validate_num_inputs(self.target, inputs, [3, 4, 6])
115+
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
109116
validate_same_dtype(self.target, [inputs[0], output], ts)
110117
validate_valid_dtype(
111118
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
@@ -141,7 +148,7 @@ def define_node(
141148
) -> None:
142149
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
143150

144-
validate_num_inputs(self.target, inputs, [3, 4, 6])
151+
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
145152
validate_same_dtype(self.target, [inputs[0], output], ts)
146153
validate_valid_dtype(
147154
self.target,
@@ -192,6 +199,11 @@ def _build_generic_avgpool2d(
192199
kernel_size_list = inputs[1].special
193200
stride_size_list = inputs[2].special
194201

202+
if len(inputs) > 4:
203+
ceil_mode = bool(inputs[4].number)
204+
else:
205+
ceil_mode = False
206+
195207
try:
196208
pad_size_list = inputs[3].special
197209
pad_size_list = [
@@ -209,12 +221,14 @@ def _build_generic_avgpool2d(
209221
kernel_size_list[0],
210222
stride_size_list[0],
211223
pad_size_list[1],
224+
ceil_mode,
212225
)
213226
pad_size_list[3] = adjust_pooling_pad_if_needed(
214227
input_tensor.shape[3],
215228
kernel_size_list[1],
216229
stride_size_list[1],
217230
pad_size_list[3],
231+
ceil_mode,
218232
)
219233

220234
attr = ts.TosaSerializerAttribute()
@@ -247,7 +261,7 @@ def define_node(
247261
) -> None:
248262
import serializer.tosa_serializer as ts # type: ignore
249263

250-
validate_num_inputs(self.target, inputs, [3, 4, 6])
264+
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
251265
validate_same_dtype(self.target, [inputs[0], output], ts)
252266
validate_valid_dtype(
253267
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
@@ -286,7 +300,7 @@ def define_node(
286300
) -> None:
287301
import serializer.tosa_serializer as ts # type: ignore
288302

289-
validate_num_inputs(self.target, inputs, [3, 4, 6])
303+
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
290304
validate_same_dtype(self.target, [inputs[0], output], ts)
291305
validate_valid_dtype(
292306
self.target,

0 commit comments

Comments
 (0)