Skip to content

Commit 08dfe52

Browse files
authored
Arm backend: Adjust AvgPool2d padding when window is not divisible by stride (#10972)
* AvgPool2dVisitor will adjust the padding so the pooling window is divisible by the stride * Improve tests in test_max_pool.py Signed-off-by: Tom Allsop <[email protected]>
1 parent d509ee3 commit 08dfe52

File tree

5 files changed

+85
-22
lines changed

5 files changed

+85
-22
lines changed

backends/arm/operators/op_avg_pool2d.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
register_node_visitor,
1818
)
1919
from executorch.backends.arm.operators.operator_validation_utils import (
20+
adjust_pooling_pad_if_needed,
2021
validate_num_inputs,
2122
validate_same_dtype,
2223
)
@@ -63,6 +64,20 @@ def _build_generic_avgpool2d(
6364
except IndexError:
6465
pad_size_list = [0, 0, 0, 0]
6566

67+
# Adjust the padding as necessary
68+
pad_size_list[1] = adjust_pooling_pad_if_needed(
69+
input_tensor.shape[2],
70+
kernel_size_list[0],
71+
stride_size_list[0],
72+
pad_size_list[1],
73+
)
74+
pad_size_list[3] = adjust_pooling_pad_if_needed(
75+
input_tensor.shape[3],
76+
kernel_size_list[1],
77+
stride_size_list[1],
78+
pad_size_list[3],
79+
)
80+
6681
attr = ts.TosaSerializerAttribute()
6782
attr.PoolAttribute(
6883
kernel=kernel_size_list,
@@ -192,6 +207,20 @@ def _build_generic_avgpool2d(
192207
except IndexError:
193208
pad_size_list = [0, 0, 0, 0]
194209

210+
# Adjust the padding as necessary
211+
pad_size_list[1] = adjust_pooling_pad_if_needed(
212+
input_tensor.shape[2],
213+
kernel_size_list[0],
214+
stride_size_list[0],
215+
pad_size_list[1],
216+
)
217+
pad_size_list[3] = adjust_pooling_pad_if_needed(
218+
input_tensor.shape[3],
219+
kernel_size_list[1],
220+
stride_size_list[1],
221+
pad_size_list[3],
222+
)
223+
195224
attr = ts.TosaSerializerAttribute()
196225
attr.AvgPool2dAttribute(
197226
kernel=kernel_size_list,

backends/arm/operators/op_max_pool2d.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,14 @@
1717
register_node_visitor,
1818
)
1919
from executorch.backends.arm.operators.operator_validation_utils import (
20+
adjust_pooling_pad_if_needed,
2021
validate_num_inputs,
2122
validate_same_dtype,
2223
)
2324
from executorch.backends.arm.tosa_mapping import TosaArg
2425
from executorch.backends.arm.tosa_specification import TosaSpecification
2526

2627

27-
# Similarly to Conv2d, the TOSA spec requires that the following is exactly divisible:
28-
# `(input + 2 * pad - kernel_size) / stride`
29-
# PyTorch however, does not require this, so as needed, we must adjust the padding.
30-
def adjust_pad_if_needed(
31-
input_size: int, kernel_size: int, stride: int, pad: int
32-
) -> int:
33-
if pad == 0:
34-
return pad
35-
36-
mod_remainder = (input_size + 2 * pad - kernel_size) % stride
37-
38-
# No need to adjust
39-
if mod_remainder == 0:
40-
return pad
41-
42-
return pad - mod_remainder
43-
44-
4528
@register_node_visitor
4629
class MaxPool2dVisitor_0_80(NodeVisitor):
4730
target = "aten.max_pool2d.default"
@@ -82,13 +65,13 @@ def define_node(
8265
pad_size_list = [0, 0, 0, 0]
8366

8467
# Adjust the padding as necessary
85-
pad_size_list[1] = adjust_pad_if_needed(
68+
pad_size_list[1] = adjust_pooling_pad_if_needed(
8669
input_tensor.shape[2],
8770
kernel_size[0],
8871
stride[0],
8972
pad_size_list[1],
9073
)
91-
pad_size_list[3] = adjust_pad_if_needed(
74+
pad_size_list[3] = adjust_pooling_pad_if_needed(
9275
input_tensor.shape[3],
9376
kernel_size[1],
9477
stride[1],
@@ -167,13 +150,13 @@ def define_node(
167150
pad_size_list = [0, 0, 0, 0]
168151

169152
# Adjust the padding as necessary
170-
pad_size_list[1] = adjust_pad_if_needed(
153+
pad_size_list[1] = adjust_pooling_pad_if_needed(
171154
input_tensor.shape[2],
172155
kernel_size[0],
173156
stride[0],
174157
pad_size_list[1],
175158
)
176-
pad_size_list[3] = adjust_pad_if_needed(
159+
pad_size_list[3] = adjust_pooling_pad_if_needed(
177160
input_tensor.shape[3],
178161
kernel_size[1],
179162
stride[1],

backends/arm/operators/operator_validation_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any]):
9999
f"{op_name}: Expected all tensors to have dtype {reference_dtype}, but "
100100
f"found inconsistent dtype {tensor.dtype}."
101101
)
102+
103+
104+
def adjust_pooling_pad_if_needed(
105+
input_size: int, kernel_size: int, stride: int, pad: int
106+
) -> int:
107+
"""
108+
Calculates the adjusted padding for a pooling window so that
109+
`(input + 2 * pad - kernel_size)` is exactly divisible by the stride. All inputs should correspond to the same dimension.
110+
111+
Parameters:
112+
-----------
113+
input_size : int
114+
The size of the input to the operator.
115+
116+
kernel_size : int
117+
The size of the kernel.
118+
119+
stride : int
120+
The size of the stride.
121+
122+
pad : int
123+
The amount of padding.
124+
125+
Output:
126+
-------
127+
An int, representing the adjusted padding value (the original padding with any excess removed) that makes the window divisible.
128+
"""
129+
if pad == 0:
130+
return pad
131+
132+
mod_remainder = (input_size + 2 * pad - kernel_size) % stride
133+
134+
# No need to adjust
135+
if mod_remainder == 0:
136+
return pad
137+
138+
return pad - mod_remainder

backends/arm/test/ops/test_avg_pool2d.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ def forward(self, x):
5959
AvgPool2d((4, 6), (1, 2), (2, 3)),
6060
(torch.rand(1, 16, 50, 32),),
6161
),
62+
"non_divisible_window": lambda: (
63+
AvgPool2d(3, 2, 1),
64+
(torch.rand(1, 16, 112, 112),),
65+
),
66+
"non_divisible_window_height": lambda: (
67+
AvgPool2d(3, (2, 1), 1),
68+
(torch.rand(1, 16, 56, 56),),
69+
),
70+
"non_divisible_window_width": lambda: (
71+
AvgPool2d(3, (1, 2), 1),
72+
(torch.rand(1, 16, 56, 56),),
73+
),
6274
}
6375

6476

backends/arm/test/ops/test_max_pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
"ones": lambda: (torch.ones(1, 16, 50, 32), [4, 2, 0]),
2727
"rand": lambda: (torch.rand(1, 16, 52, 16), [4, 3, 0]),
2828
"non_divisible": lambda: (torch.rand(1, 16, 112, 112), [3, 2, 1]),
29+
"non_divisible_window_height": lambda: (torch.rand(1, 16, 56, 56), [3, (2, 1), 1]),
30+
"non_divisible_window_width": lambda: (torch.rand(1, 16, 56, 56), [3, (1, 2), 1]),
2931
}
3032

3133
test_data_suite_mult_batches = {

0 commit comments

Comments
 (0)