Skip to content

Arm backend: Improve pooling args handling #11819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
from .convert_to_clamp import ConvertToClampPass # noqa
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ConvertSplitToSlicePass,
ConvertSqueezesToViewPass,
ConvertToClampPass,
DecomposeAvgPool2d,
DecomposeCosineSimilarityPass,
DecomposeDivPass,
DecomposeEmbeddingPass,
Expand Down Expand Up @@ -63,7 +64,6 @@
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)

from executorch.backends.arm.tosa_specification import (
TosaLoweringContext,
TosaSpecification,
Expand Down Expand Up @@ -115,6 +115,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
if self.tosa_spec.is_U55_subset:
self.add_pass(BroadcastArgsPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(DecomposeAvgPool2d())
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(RemoveClonePass())
Expand Down Expand Up @@ -172,6 +173,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeAvgPool2d())
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(RemoveClonePass())
Expand Down Expand Up @@ -232,6 +234,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeLinearVectorNormPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(DecomposeSiluPass())
self.add_pass(DecomposeAvgPool2d())

if self.tosa_spec.is_U55_subset:
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Expand Down
121 changes: 121 additions & 0 deletions backends/arm/_passes/decompose_avg_pool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,)
aten_div_ops = (torch.ops.aten.avg_pool2d.default,)


def get_decomposition(op) -> tuple:
if op in edge_div_ops:
return (
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.mul.Tensor,
)
if op in aten_div_ops:
return (
torch.ops.aten.full.default,
torch.ops.aten.cat.default,
torch.ops.aten.avg_pool2d.default,
torch.ops.aten.mul.Tensor,
)
raise RuntimeError(f"Can't get div decomposition for op {op}")


class DecomposeAvgPool2d(ExportPass):
""" """

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_div_ops + aten_div_ops):
return super().call_operator(op, args, kwargs, meta)

full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)

x = args[0]
kernel_h, kernel_w = args[1]
kernel_size = kernel_h * kernel_w
stride_h, stride_w = args[2]
pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
ceil_mode = args[4] if len(args) > 4 else False
count_include_pad = args[5] if len(args) > 5 else True
divisor_override = args[6] if len(args) > 6 else None

n, c, h, w = x.data.shape
post_pad_w, post_pad_h = (0, 0)

# Count_include_pad == False means that we use a different divisor for edge elements
# When divisor_override is set, this will be overriden anyways.
# It is easier to replace a constant divisor, so set count_include_pad == True
if divisor_override is not None:
count_include_pad = True

# Add width padding manually if count_include_pad
if count_include_pad and pad_w > 0:
pre_pad_shape = [n, c, h, pad_w]
pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)

if ceil_mode and divisor_override is None:
post_pad_w = pad_w
else:
post_pad_w = adjust_pooling_pad_if_needed(
w, kernel_w, stride_w, pad_w, ceil_mode
)

if post_pad_w > 0:
post_pad_shape = [n, c, h, post_pad_w]
post_pad = super().call_operator(
full_op, (post_pad_shape, 0.0), kwargs, meta
)
cat_nodes = [pre_pad, x, post_pad]
else:
cat_nodes = [pre_pad, x]

x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta)
new_pad_w = 0

# Add height padding manually if count_include_pad
if count_include_pad and pad_h > 0:
pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)

if ceil_mode and divisor_override is None:
post_pad_h = pad_h
else:
post_pad_h = adjust_pooling_pad_if_needed(
h, kernel_h, stride_h, pad_h, ceil_mode
)

if post_pad_h > 0:
post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
post_pad = super().call_operator(
full_op, (post_pad_shape, 0.0), kwargs, meta
)
cat_nodes = [pre_pad, x, post_pad]
else:
cat_nodes = [pre_pad, x]

x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
new_pad_h = 0

avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)

# Multiply by factor (kernel_size / divisor_override) if divisor_override
if divisor_override is not None and divisor_override != kernel_size:
override_multiplier = super().call_operator(
full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta
)
x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta)

return x
10 changes: 4 additions & 6 deletions backends/arm/_passes/decompose_maxpool2d_with_dilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def call_operator(self, op, args, kwargs, meta):
stride = args[2]
padding = args[3] if len(args) >= 4 else 0
dilation = args[4] if len(args) >= 5 else 1
ceil_mode = args[5] if len(args) == 6 else False

# Normalize attributes
pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
Expand All @@ -45,12 +46,9 @@ def call_operator(self, op, args, kwargs, meta):
)
s_h, s_w = (stride, stride) if isinstance(stride, int) else stride

# If no dilation: call EXIR edge op with only supported args (x, kernel, stride[, padding])
# If no dilation: call EXIR edge op
if d_h == 1 and d_w == 1:
minimal_args = [x, kernel_size, stride]
# only include padding if non-zero
if (pad_h, pad_w) != (0, 0):
minimal_args.append((pad_h, pad_w))
minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode]
return super().call_operator(op, tuple(minimal_args), {}, meta)

# Compute padded and packed dimensions for dilation > 1
Expand Down Expand Up @@ -102,7 +100,7 @@ def call_operator(self, op, args, kwargs, meta):
if is_with_indices
else exir_ops.edge.aten.max_pool2d.default
)
pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0))
pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode)
pool_out = super().call_operator(
pool_edge_op,
pool_args,
Expand Down
52 changes: 36 additions & 16 deletions backends/arm/operator_support/pool_2d_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

Expand Down Expand Up @@ -56,25 +59,42 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
input_arg = get_first_fake_tensor(input_arg)
shape = input_arg.data.shape # type: ignore[union-attr]

# Calculate padding used in the final TOSA operator
kernel = cast(tuple[int, int], node.args[1])
stride = cast(tuple[int, int], node.args[2])
if len(node.args) > 3:
padding = cast(tuple[int, int], node.args[3])
# Padding case
if not all(1 <= k <= 8 for k in kernel) and not all(
v == 0 for v in padding
):
self.reporter.report_reject(
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
)
return False
padding = cast(tuple[int, int], node.args[3]) if len(node.args) > 3 else (0, 0)
ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False
count_include_pad = cast(bool, node.args[5]) if len(node.args) > 5 else True
divisor_override = cast(int, node.args[6]) if len(node.args) > 6 else None

# If count_include_pad is True or divior_override is given, padding is applied
# by concating zero-elements rather than setting it in the avg_pool op.
if count_include_pad or divisor_override is not None:
tosa_padding = (0, 0, 0, 0)
# Otherwise, calculate the padding as done in the node visitor
else:
if not kernel_check(kernel):
self.reporter.report_reject(
node,
f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
)
return False
post_pad_h = adjust_pooling_pad_if_needed(
shape[2], kernel[0], stride[0], padding[0], ceil_mode
)
post_pad_w = adjust_pooling_pad_if_needed(
shape[3], kernel[1], stride[1], padding[1], ceil_mode
)
tosa_padding = (padding[0], post_pad_h, padding[1], post_pad_w)

if not all(1 <= k <= 8 for k in kernel) and not all(
v == 0 for v in tosa_padding
):
self.reporter.report_reject(
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
)
return False

if not kernel_check(kernel):
self.reporter.report_reject(
node,
f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
)
return False

if not dim_check(shape):
self.reporter.report_reject(
Expand Down
22 changes: 18 additions & 4 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def _build_generic_avgpool2d(
kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special

if len(inputs) > 4:
ceil_mode = bool(inputs[4].number)
else:
ceil_mode = False

try:
pad_size_list = inputs[3].special
pad_size_list = [
Expand All @@ -71,12 +76,14 @@ def _build_generic_avgpool2d(
kernel_size_list[0],
stride_size_list[0],
pad_size_list[1],
ceil_mode,
)
pad_size_list[3] = adjust_pooling_pad_if_needed(
input_tensor.shape[3],
kernel_size_list[1],
stride_size_list[1],
pad_size_list[3],
ceil_mode,
)

attr = ts.TosaSerializerAttribute()
Expand Down Expand Up @@ -105,7 +112,7 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
Expand Down Expand Up @@ -141,7 +148,7 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target,
Expand Down Expand Up @@ -192,6 +199,11 @@ def _build_generic_avgpool2d(
kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special

if len(inputs) > 4:
ceil_mode = bool(inputs[4].number)
else:
ceil_mode = False

try:
pad_size_list = inputs[3].special
pad_size_list = [
Expand All @@ -209,12 +221,14 @@ def _build_generic_avgpool2d(
kernel_size_list[0],
stride_size_list[0],
pad_size_list[1],
ceil_mode,
)
pad_size_list[3] = adjust_pooling_pad_if_needed(
input_tensor.shape[3],
kernel_size_list[1],
stride_size_list[1],
pad_size_list[3],
ceil_mode,
)

attr = ts.TosaSerializerAttribute()
Expand Down Expand Up @@ -247,7 +261,7 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
Expand Down Expand Up @@ -286,7 +300,7 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target,
Expand Down
Loading
Loading