Commit a27bbd1

GregoryComer authored and facebook-github-bot committed
Support ceil_mode=True on maxpool2d in XNNPACK delegate
Summary: Support ceil_mode=True for MaxPool2d in the XNNPACK delegate by conditionally updating padding. This works only when the input shape is static, as it requires statically computing the change in padding. As such, the partitioner constraints are also updated to reflect this.

Differential Revision: D67386151
1 parent eca5d9f commit a27bbd1
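
As background for the change, ceil_mode=True rounds the pooling output size up instead of down, which this commit emulates by growing the bottom/right padding. A quick illustration of the behavior in eager PyTorch, with arbitrary sizes:

    import torch

    x = torch.randn(1, 1, 5, 5)
    # kernel 2, stride 2 over a 5-wide input: floor((5 - 2) / 2) + 1 = 2 windows,
    # but ceil mode rounds up to 3 by letting the last window overhang the edge.
    print(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)(x).shape)   # [1, 1, 3, 3]
    print(torch.nn.MaxPool2d(2, stride=2, ceil_mode=False)(x).shape)  # [1, 1, 2, 2]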

File tree: 3 files changed, +85 −13 lines

backends/xnnpack/operators/op_max_pool2d.py

Lines changed: 35 additions & 3 deletions
@@ -51,9 +51,9 @@ def define_node(
        kwargs["output_id"] = vals_to_ids[node]

        # kernel info
-        kernal_shape = cast(List[int], node.args[1])
-        kwargs["pooling_height"] = kernal_shape[0]
-        kwargs["pooling_width"] = kernal_shape[1]
+        kernel_shape = cast(List[int], node.args[1])
+        kwargs["pooling_height"] = kernel_shape[0]
+        kwargs["pooling_width"] = kernel_shape[1]

        # stride info
        stride = cast(List[int], node.args[2])
@@ -80,6 +80,18 @@ def define_node(
        dilation = cast(List[int], node.args[4])
        kwargs["dilation_height"] = dilation[0]
        kwargs["dilation_width"] = dilation[1]
+
+        # ceil mode
+        ceil_mode = node.args[5] if len(node.args) > 5 else False
+        if ceil_mode:
+            # use original input shape as xnnpack input may be permuted
+            orig_input_shape = node.all_input_nodes[0].meta["val"].shape
+            kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
+                orig_input_shape[2], kernel_shape[0], stride[0], padding_shape[0], dilation[0]
+            )
+            kwargs["padding_right"] += self.calculate_pad_amount_1d(
+                orig_input_shape[3], kernel_shape[1], stride[1], padding_shape[1], dilation[1]
+            )

        kwargs["flags"] = XNN_FLAG_KEEP_DIMS

@@ -90,3 +102,23 @@ def define_node(
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
+
+    def calculate_pad_amount_1d(self, in_size, kernel_size, stride, padding, dilation):
+        # Determine the number of padding elements to add along a single dimension
+        # to match the ceil_mode=True behavior.
+        # See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
+
+        # Determine the number of input elements to exactly bump up the output size
+        # by 1. Note that there is an additional condition to subtract 1 from the
+        # output when ceil_mode=True and (output_size - 1) * stride >= in_size + padding.
+        # In this case, we don't need to pad, as ceil_mode=False and True give the
+        # same behavior.
+        numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
+        numerator = numerator_no_ceil + stride - 1
+        output_size = numerator // stride + 1
+
+        needs_adjust = (output_size - 1) * stride >= in_size + padding
+        partial_stride = numerator_no_ceil % stride
+        pad_out = (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0
+
+        return pad_out
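
The arithmetic above can be sanity-checked against eager PyTorch: padding the bottom/right by the computed amount (with -inf, so maxima are unchanged) and pooling with ceil_mode=False should reproduce the ceil_mode=True result. A minimal sketch, using a standalone copy of the helper and arbitrary sizes:

    import torch
    import torch.nn.functional as F

    def calculate_pad_amount_1d(in_size, kernel_size, stride, padding, dilation):
        # Standalone copy of the helper above, for checking against PyTorch.
        numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
        numerator = numerator_no_ceil + stride - 1
        output_size = numerator // stride + 1
        needs_adjust = (output_size - 1) * stride >= in_size + padding
        partial_stride = numerator_no_ceil % stride
        return (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0

    x = torch.randn(1, 1, 16, 16)
    ceil_out = F.max_pool2d(x, kernel_size=3, stride=2, ceil_mode=True)

    # Pad bottom/right with -inf by the computed amount, then pool with ceil_mode=False.
    extra = calculate_pad_amount_1d(16, kernel_size=3, stride=2, padding=0, dilation=1)
    padded = F.pad(x, (0, extra, 0, extra), value=float("-inf"))
    floor_out = F.max_pool2d(padded, kernel_size=3, stride=2, ceil_mode=False)

    assert extra == 1
    assert torch.equal(ceil_out, floor_out)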

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 6 additions & 2 deletions
@@ -287,9 +287,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            return False

+        # Ceil mode is supported via op padding, which must be statically known.
        is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        if is_ceil_mode:
-            why(node, reason="ceil mode is not supported")
+        is_dynamic = "val" in node.meta and any(
+            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
+        )
+        if is_ceil_mode and is_dynamic:
+            why(node, reason="ceil mode is not supported for dynamic shapes")
            return False
        return True
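
For context on the new is_dynamic check: dimensions marked dynamic at export time surface as torch.SymInt in a node's meta["val"].shape, which is exactly what the constraint tests for. A minimal sketch (the module and dim names are illustrative, not from this commit):

    import torch
    from torch.export import Dim, export

    class Pool(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.max_pool2d(x, 2, ceil_mode=True)

    ep = export(
        Pool(),
        (torch.randn(1, 8, 16, 16),),
        dynamic_shapes={"x": {3: Dim("w", min=4, max=32)}},
    )
    for node in ep.graph.nodes:
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):
            # The dynamic width prints as SymInt; static dims print as int.
            print(node.name, [type(d).__name__ for d in val.shape])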

backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 44 additions & 8 deletions
@@ -4,10 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
import unittest

import torch
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from torch.export.dynamic_shapes import Dim


class TestMaxPool2d(unittest.TestCase):
@@ -38,10 +40,10 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
        def forward(self, x):
            return self.max_pool2d_module(x)[1]

-    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
-        def __init__(self):
+    class MaxPool2dCeilMode(torch.nn.Module):
+        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
-            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+            self.max_pool2d_module = torch.nn.MaxPool2d(kernel_size, stride, padding, dilation, ceil_mode=True)

        def forward(self, x):
            return self.max_pool2d_module(x)
@@ -92,15 +94,49 @@ def test_fp32_maxpool2d_unsupported(self):
                }
            )
        )
+    def test_fp32_maxpool2d_ceilmode(self):
+        input_sizes = [[17, 32], [32, 37]]
+        kernel_sizes = [2, 3, 12]
+        strides = [1, 2, 4]
+        padding = [0, 1, 5]
+        dilation = [1, 2, 3]
+
+        for input_size, kernel_size, stride, pad, dilation in itertools.product(input_sizes, kernel_sizes, strides, padding, dilation):
+            # Check XNNPACK and PyTorch constraints
+            if pad > ((kernel_size - 1) * dilation + 1) / 2:
+                continue
+            if stride > kernel_size:
+                continue
+            if any((size + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1 <= 0 for size in input_size):  # Output size too small
+                continue
+
+            inputs = (torch.randn(1, 1, input_size[0], input_size[1]),)
+            (
+                Tester(self.MaxPool2dCeilMode(kernel_size, stride, pad, dilation), inputs)
+                .export()
+                .check_count({"torch.ops.aten.max_pool2d.default": 1})
+                .to_edge_transform_and_lower()
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                .check_not(
+                    [
+                        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
+                    ]
+                )
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
+            )

-    def test_fp32_maxpool2d_unsupported_ceilmode(self):
+    def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
        """
-        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
+        MaxPool2d with ceil mode is not supported with dynamic shapes (see maxpool2d constraint).
        """
        inputs = (torch.randn(1, 32, 23, 23),)
+        dim3 = Dim('_dim3', min=11, max=50)
+        dynamic_shapes = {"x": {3: 2*dim3 - 1}}
        (
-            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
-            .export()
+            Tester(self.MaxPool2dCeilMode(), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated.