
Commit 0711133
Support ceil_mode=True on maxpool2d in XNNPACK delegate
Summary: Support ceil_mode=True for MaxPool2d in the XNNPACK delegate by conditionally updating padding. This works only when the input shape is static, as it requires statically computing the change in padding. The partitioner constraints are updated to reflect this.

Differential Revision: D67386151
1 parent eca5d9f
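For intuition: ceil_mode only changes the pooling output-size formula from floor to ceiling division, so the same output length can be obtained in floor mode by adding at most stride − 1 extra padding elements on the bottom/right edge. A small illustrative sketch of that arithmetic (pool_out_size is a hypothetical helper, not part of this commit):

import math

# Output length of one pooled dimension, per the torch.nn.MaxPool2d docs.
def pool_out_size(in_size, kernel, stride, padding, dilation, ceil_mode):
    eff = in_size + 2 * padding - dilation * (kernel - 1) - 1
    return (math.ceil(eff / stride) if ceil_mode else eff // stride) + 1

# Example: 5 inputs, kernel 2, stride 2, no padding.
print(pool_out_size(5, 2, 2, 0, 1, ceil_mode=False))  # 2
print(pool_out_size(5, 2, 2, 0, 1, ceil_mode=True))   # 3
# One extra element of bottom/right padding reproduces ceil mode in floor mode:
print(pool_out_size(5 + 1, 2, 2, 0, 1, ceil_mode=False))  # 3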

3 files changed: +105 −13 lines

backends/xnnpack/operators/op_max_pool2d.py

Lines changed: 45 additions & 3 deletions
@@ -51,9 +51,9 @@ def define_node(
         kwargs["output_id"] = vals_to_ids[node]
 
         # kernel info
-        kernal_shape = cast(List[int], node.args[1])
-        kwargs["pooling_height"] = kernal_shape[0]
-        kwargs["pooling_width"] = kernal_shape[1]
+        kernel_shape = cast(List[int], node.args[1])
+        kwargs["pooling_height"] = kernel_shape[0]
+        kwargs["pooling_width"] = kernel_shape[1]
 
         # stride info
         stride = cast(List[int], node.args[2])
@@ -81,6 +81,26 @@ def define_node(
         kwargs["dilation_height"] = dilation[0]
         kwargs["dilation_width"] = dilation[1]
 
+        # ceil mode
+        ceil_mode = node.args[5] if len(node.args) > 5 else False
+        if ceil_mode:
+            # use original input shape as xnnpack input may be permuted
+            orig_input_shape = node.all_input_nodes[0].meta["val"].shape
+            kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
+                orig_input_shape[2],
+                kernel_shape[0],
+                stride[0],
+                padding_shape[0],
+                dilation[0],
+            )
+            kwargs["padding_right"] += self.calculate_pad_amount_1d(
+                orig_input_shape[3],
+                kernel_shape[1],
+                stride[1],
+                padding_shape[1],
+                dilation[1],
+            )
+
         kwargs["flags"] = XNN_FLAG_KEEP_DIMS
 
         ser_node = XNode(
@@ -90,3 +110,25 @@ def define_node(
             debug_handle=debug_handle,
         )
         xnn_graph.xnodes.append(ser_node)
+
+    def calculate_pad_amount_1d(self, in_size, kernel_size, stride, padding, dilation):
+        # Determine the number of padding elements to add along a single dimension
+        # to match the ceil_mode=True behavior.
+        # See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
+
+        # Determine the number of input elements needed to bump the output size up
+        # by exactly 1. Note that there is an additional condition to subtract 1 from
+        # the output when ceil_mode=True and (output_size - 1) * stride >= in_size + padding.
+        # In that case, we don't need to pad, as ceil_mode=False and True give the
+        # same behavior.
+        numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
+        numerator = numerator_no_ceil + stride - 1
+        output_size = numerator // stride + 1
+
+        needs_adjust = (output_size - 1) * stride >= in_size + padding
+        partial_stride = numerator_no_ceil % stride
+        pad_out = (
+            (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0
+        )
+
+        return pad_out
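The helper's arithmetic can be sanity-checked against PyTorch directly: pad one side of the input by the computed amount, compute the floor-mode output length, and compare it with real ceil-mode pooling. A minimal standalone sketch (a restatement of the helper above plus a hypothetical brute-force check, assuming only stock torch):

import torch

# Standalone restatement of calculate_pad_amount_1d from the diff above.
def calculate_pad_amount_1d(in_size, kernel_size, stride, padding, dilation):
    numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
    numerator = numerator_no_ceil + stride - 1
    output_size = numerator // stride + 1
    needs_adjust = (output_size - 1) * stride >= in_size + padding
    partial_stride = numerator_no_ceil % stride
    return (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0

# Compare the floor-mode output length on a one-side-padded input against
# PyTorch's actual ceil-mode output length.
for in_size in range(8, 64):
    k, s, p, d = 3, 2, 1, 1
    x = torch.randn(1, 1, in_size)
    ceil_len = torch.nn.functional.max_pool1d(
        x, k, stride=s, padding=p, dilation=d, ceil_mode=True
    ).shape[-1]
    extra = calculate_pad_amount_1d(in_size, k, s, p, d)
    floor_len = (in_size + extra + 2 * p - d * (k - 1) - 1) // s + 1
    assert floor_len == ceil_len, (in_size, extra, floor_len, ceil_len)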

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 6 additions & 2 deletions
@@ -287,9 +287,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
+        # Ceil mode is supported via op padding, which must be statically known.
         is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        if is_ceil_mode:
-            why(node, reason="ceil mode is not supported")
+        is_dynamic = "val" in node.meta and any(
+            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
+        )
+        if is_ceil_mode and is_dynamic:
+            why(node, reason="ceil mode is not supported for dynamic shapes")
             return False
         return True
 
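For reference, the is_dynamic check keys on torch.SymInt entries in the FakeTensor that export records under node.meta["val"]. A minimal sketch of where those come from, using the public torch.export API rather than the partitioner itself (the module M and dim name "w" are illustrative):

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(x, 2, stride=2, ceil_mode=True)

# Export with one dynamic spatial dimension.
ep = export(
    M(),
    (torch.randn(1, 3, 16, 16),),
    dynamic_shapes={"x": {3: Dim("w", min=4, max=64)}},
)
for node in ep.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        # Dynamic dims appear as torch.SymInt, which is exactly what the
        # partitioner constraint above rejects for ceil_mode pools.
        print(node.name, [type(d).__name__ for d in val.shape])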

backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 54 additions & 8 deletions
@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import unittest
 
 import torch
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from torch.export.dynamic_shapes import Dim
 
 
 class TestMaxPool2d(unittest.TestCase):
@@ -38,10 +40,12 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
         def forward(self, x):
             return self.max_pool2d_module(x)[1]
 
-    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
-        def __init__(self):
+    class MaxPool2dCeilMode(torch.nn.Module):
+        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
             super().__init__()
-            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+            self.max_pool2d_module = torch.nn.MaxPool2d(
+                kernel_size, stride, padding, dilation, ceil_mode=True
+            )
 
         def forward(self, x):
             return self.max_pool2d_module(x)
@@ -93,14 +97,56 @@ def test_fp32_maxpool2d_unsupported(self):
             )
         )
 
-    def test_fp32_maxpool2d_unsupported_ceilmode(self):
+    def test_fp32_maxpool2d_ceilmode(self):
+        input_sizes = [[17, 32], [32, 37]]
+        kernel_sizes = [2, 3, 12]
+        strides = [1, 2, 4]
+        padding = [0, 1, 5]
+        dilations = [1, 2, 3]
+
+        for input_size, kernel_size, stride, pad, dilation in itertools.product(
+            input_sizes, kernel_sizes, strides, padding, dilations
+        ):
+            # Check XNNPACK and PyTorch constraints
+            if pad > ((kernel_size - 1) * dilation + 1) / 2:
+                continue
+            if stride > kernel_size:
+                continue
+            if any(
+                (size + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1 <= 0
+                for size in input_size
+            ):  # Output size too small
+                continue
+
+            inputs = (torch.randn(1, 1, input_size[0], input_size[1]),)
+            (
+                Tester(
+                    self.MaxPool2dCeilMode(kernel_size, stride, pad, dilation), inputs
+                )
+                .export()
+                .check_count({"torch.ops.aten.max_pool2d.default": 1})
+                .to_edge_transform_and_lower()
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                .check_not(
+                    [
+                        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
+                    ]
+                )
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
         """
-        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
+        MaxPool2d with ceil mode is not supported with dynamic shapes (see maxpool2d constraint).
         """
         inputs = (torch.randn(1, 32, 23, 23),)
+        dim3 = Dim("_dim3", min=11, max=50)
+        dynamic_shapes = {"x": {3: 2 * dim3 - 1}}
         (
-            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
-            .export()
+            Tester(self.MaxPool2dCeilMode(), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
             .check_count({"torch.ops.aten.max_pool2d.default": 1})
             .to_edge_transform_and_lower()
             # We expect it not to be delegated.
