Arm backend: Adjust AvgPool2d padding when window is not divisible by stride #10972


Merged · 3 commits · May 20, 2025
29 changes: 29 additions & 0 deletions backends/arm/operators/op_avg_pool2d.py
@@ -17,6 +17,7 @@
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
validate_num_inputs,
validate_same_dtype,
)
@@ -63,6 +64,20 @@ def _build_generic_avgpool2d(
except IndexError:
pad_size_list = [0, 0, 0, 0]

# Adjust the padding as necessary
pad_size_list[1] = adjust_pooling_pad_if_needed(
input_tensor.shape[2],
kernel_size_list[0],
stride_size_list[0],
pad_size_list[1],
)
pad_size_list[3] = adjust_pooling_pad_if_needed(
input_tensor.shape[3],
kernel_size_list[1],
stride_size_list[1],
pad_size_list[3],
)

attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(
kernel=kernel_size_list,
@@ -192,6 +207,20 @@ def _build_generic_avgpool2d(
except IndexError:
pad_size_list = [0, 0, 0, 0]

# Adjust the padding as necessary
pad_size_list[1] = adjust_pooling_pad_if_needed(
input_tensor.shape[2],
kernel_size_list[0],
stride_size_list[0],
pad_size_list[1],
)
pad_size_list[3] = adjust_pooling_pad_if_needed(
input_tensor.shape[3],
kernel_size_list[1],
stride_size_list[1],
pad_size_list[3],
)

attr = ts.TosaSerializerAttribute()
attr.AvgPool2dAttribute(
kernel=kernel_size_list,
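To make the new calls concrete, here is a minimal sketch (not part of the diff) of the arithmetic `adjust_pooling_pad_if_needed` performs, using the `non_divisible_window` test case added below (`AvgPool2d(3, 2, 1)` on a 112x112 input). Indices 1 and 3 of `pad_size_list` are the bottom and right pads in TOSA's `[top, bottom, left, right]` ordering, so only the trailing edge of each dimension is trimmed:

```python
# Height dimension of AvgPool2d(3, 2, 1) on a (1, 16, 112, 112) input;
# the width dimension is identical.
input_size, kernel_size, stride, pad = 112, 3, 2, 1

mod_remainder = (input_size + 2 * pad - kernel_size) % stride  # 111 % 2 == 1
trailing_pad = pad - mod_remainder                             # 1 - 1 == 0

# With pads (top=1, bottom=0) the TOSA divisibility constraint holds exactly...
assert (input_size + pad + trailing_pad - kernel_size) % stride == 0
# ...and the output size still matches PyTorch's floor-division result of 56.
assert (input_size + pad + trailing_pad - kernel_size) // stride + 1 == 56
```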
27 changes: 5 additions & 22 deletions backends/arm/operators/op_max_pool2d.py
@@ -17,31 +17,14 @@
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification


# Similarly to Conv2d, the TOSA spec requires that the following is exactly divisible:
# `(input + 2 * pad - kernel_size) / stride`
# PyTorch, however, does not require this, so the padding must be adjusted as needed.
def adjust_pad_if_needed(
input_size: int, kernel_size: int, stride: int, pad: int
) -> int:
if pad == 0:
return pad

mod_remainder = (input_size + 2 * pad - kernel_size) % stride

# No need to adjust
if mod_remainder == 0:
return pad

return pad - mod_remainder


@register_node_visitor
class MaxPool2dVisitor_0_80(NodeVisitor):
target = "aten.max_pool2d.default"
@@ -82,13 +65,13 @@ def define_node(
pad_size_list = [0, 0, 0, 0]

# Adjust the padding as necessary
pad_size_list[1] = adjust_pad_if_needed(
pad_size_list[1] = adjust_pooling_pad_if_needed(
input_tensor.shape[2],
kernel_size[0],
stride[0],
pad_size_list[1],
)
pad_size_list[3] = adjust_pad_if_needed(
pad_size_list[3] = adjust_pooling_pad_if_needed(
input_tensor.shape[3],
kernel_size[1],
stride[1],
@@ -167,13 +150,13 @@ def define_node(
pad_size_list = [0, 0, 0, 0]

# Adjust the padding as necessary
pad_size_list[1] = adjust_pad_if_needed(
pad_size_list[1] = adjust_pooling_pad_if_needed(
input_tensor.shape[2],
kernel_size[0],
stride[0],
pad_size_list[1],
)
pad_size_list[3] = adjust_pad_if_needed(
pad_size_list[3] = adjust_pooling_pad_if_needed(
input_tensor.shape[3],
kernel_size[1],
stride[1],
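The refactor is behavior-preserving: the helper's logic is unchanged, only moved and renamed. For max pooling in particular, trimming the trailing pad is safe because the trimmed rows or columns could only ever start a partial window, which PyTorch's floor division discards anyway. A sanity check sketching this equivalence (not part of the PR; it assumes nothing beyond public `torch` APIs):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 16, 112, 112)

# Reference: symmetric padding; output size floor((112 + 2 - 3) / 2) + 1 == 56.
ref = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

# Emulate the adjusted TOSA padding: keep the leading pad, drop the trailing
# pad (1 -> 0). Max pooling implicitly pads with -inf, so pad explicitly with -inf.
x_padded = F.pad(x, (1, 0, 1, 0), value=float("-inf"))  # (left, right, top, bottom)
adj = F.max_pool2d(x_padded, kernel_size=3, stride=2, padding=0)

assert ref.shape == adj.shape == (1, 16, 56, 56)
assert torch.equal(ref, adj)
```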
37 changes: 37 additions & 0 deletions backends/arm/operators/operator_validation_utils.py
@@ -99,3 +99,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any]):
f"{op_name}: Expected all tensors to have dtype {reference_dtype}, but "
f"found inconsistent dtype {tensor.dtype}."
)


def adjust_pooling_pad_if_needed(
input_size: int, kernel_size: int, stride: int, pad: int
) -> int:
"""
Compute the adjusted trailing padding for a pooling window so that
`(input_size + 2 * pad - kernel_size)` is evenly divisible by the stride, as the
TOSA spec requires. All inputs should correspond to the same dimension.

Parameters:
-----------
input_size : int
The size of the input to the operator.

kernel_size : int
The size of the kernel.

stride : int
The size of the stride.

pad : int
The amount of padding.

Output:
-------
The adjusted padding as an int: the original `pad` if the window already divides
evenly (or `pad` is 0), otherwise `pad` reduced by the remainder.
"""
if pad == 0:
return pad

mod_remainder = (input_size + 2 * pad - kernel_size) % stride

# No need to adjust
if mod_remainder == 0:
return pad

return pad - mod_remainder
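Hypothetical calls (not in the diff) exercising each branch of the helper:

```python
# pad == 0: nothing to adjust.
assert adjust_pooling_pad_if_needed(56, 3, 2, 0) == 0
# (57 + 2*1 - 3) % 2 == 0: already divisible, padding kept as-is.
assert adjust_pooling_pad_if_needed(57, 3, 2, 1) == 1
# (56 + 2*1 - 3) % 2 == 1: padding shrunk by the remainder (1 -> 0).
assert adjust_pooling_pad_if_needed(56, 3, 2, 1) == 0
```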
12 changes: 12 additions & 0 deletions backends/arm/test/ops/test_avg_pool2d.py
@@ -59,6 +59,18 @@ def forward(self, x):
AvgPool2d((4, 6), (1, 2), (2, 3)),
(torch.rand(1, 16, 50, 32),),
),
"non_divisible_window": lambda: (
AvgPool2d(3, 2, 1),
(torch.rand(1, 16, 112, 112),),
),
"non_divisible_window_height": lambda: (
AvgPool2d(3, (2, 1), 1),
(torch.rand(1, 16, 56, 56),),
),
"non_divisible_window_width": lambda: (
AvgPool2d(3, (1, 2), 1),
(torch.rand(1, 16, 56, 56),),
),
}


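A quick sketch (not part of the test suite) of why the height- and width-only cases are non-divisible on exactly one axis: with `AvgPool2d(3, (2, 1), 1)` on a 56x56 input, only the strided dimension leaves a remainder:

```python
# Height: stride 2 leaves a remainder, so the bottom pad must be adjusted.
assert (56 + 2 * 1 - 3) % 2 == 1
# Width: stride 1 always divides evenly, so the right pad is unchanged.
assert (56 + 2 * 1 - 3) % 1 == 0
```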
2 changes: 2 additions & 0 deletions backends/arm/test/ops/test_max_pool.py
@@ -26,6 +26,8 @@
"ones": lambda: (torch.ones(1, 16, 50, 32), [4, 2, 0]),
"rand": lambda: (torch.rand(1, 16, 52, 16), [4, 3, 0]),
"non_divisible": lambda: (torch.rand(1, 16, 112, 112), [3, 2, 1]),
"non_divisible_window_height": lambda: (torch.rand(1, 16, 56, 56), [3, (2, 1), 1]),
"non_divisible_window_width": lambda: (torch.rand(1, 16, 56, 56), [3, (1, 2), 1]),
}

test_data_suite_mult_batches = {