Skip to content

Commit 212ef58

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Support ceil_mode=True on maxpool2d in XNNPACK delegate (#7355)
Summary: Support ceil_mode=True for MaxPool2d in XNNPACK delegate by conditionally updating padding. This works only when the input shape is static, as it requires statically computing the change in padding. As such, the partitioner constraints are also updated to reflect this. Test Plan: Added a new test (test_fp32_maxpool2d_ceilmode) that exercises various permutations of input with ceil_mode=True. Reviewed By: digantdesai, mcr229 Differential Revision: D67386151 Pulled By: GregoryComer
1 parent a396b47 commit 212ef58

File tree

3 files changed

+111
-13
lines changed

3 files changed

+111
-13
lines changed

backends/xnnpack/operators/op_max_pool2d.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
from typing import cast, Dict, List
810

911
import torch
@@ -51,9 +53,9 @@ def define_node(
5153
kwargs["output_id"] = vals_to_ids[node]
5254

5355
# kernel info
54-
kernal_shape = cast(List[int], node.args[1])
55-
kwargs["pooling_height"] = kernal_shape[0]
56-
kwargs["pooling_width"] = kernal_shape[1]
56+
kernel_shape = cast(List[int], node.args[1])
57+
kwargs["pooling_height"] = kernel_shape[0]
58+
kwargs["pooling_width"] = kernel_shape[1]
5759

5860
# stride info
5961
stride = cast(List[int], node.args[2])
@@ -81,6 +83,26 @@ def define_node(
8183
kwargs["dilation_height"] = dilation[0]
8284
kwargs["dilation_width"] = dilation[1]
8385

86+
# ceil mode
87+
ceil_mode = node.args[5] if len(node.args) > 5 else False
88+
if ceil_mode:
89+
# use original input shape as xnnpack input may be permuted
90+
orig_input_shape = node.all_input_nodes[0].meta["val"].shape
91+
kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
92+
orig_input_shape[2],
93+
kernel_shape[0],
94+
stride[0],
95+
padding_shape[0],
96+
dilation[0],
97+
)
98+
kwargs["padding_right"] += self.calculate_pad_amount_1d(
99+
orig_input_shape[3],
100+
kernel_shape[1],
101+
stride[1],
102+
padding_shape[1],
103+
dilation[1],
104+
)
105+
84106
kwargs["flags"] = XNN_FLAG_KEEP_DIMS
85107

86108
ser_node = XNode(
@@ -90,3 +112,25 @@ def define_node(
90112
debug_handle=debug_handle,
91113
)
92114
xnn_graph.xnodes.append(ser_node)
115+
116+
def calculate_pad_amount_1d(self, in_size, kernel_size, stride, padding, dilation):
117+
# Determine the number of padding elements to add along a single dimension
118+
# to match the ceil_mode=True behavior.
119+
# See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
120+
121+
# Determine the number of input elements to exactly bump up the output size
122+
# by 1. Note that there is an additional condition to substract 1 from the
123+
# output when ceil_mode=True and (output_size - 1) * stride >= in_size + padding
124+
# In this case, we don't need to pad, as ceil_mode=False and True give the
125+
# same behavior.
126+
numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
127+
numerator = numerator_no_ceil + stride - 1
128+
output_size = numerator // stride + 1
129+
130+
needs_adjust = (output_size - 1) * stride >= in_size + padding
131+
partial_stride = numerator_no_ceil % stride
132+
pad_out = (
133+
(stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0
134+
)
135+
136+
return pad_out

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
import logging
810
from typing import cast, List, Optional
911

@@ -287,9 +289,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
287289
if not self.check_common_constraints(node, ep):
288290
return False
289291

292+
# Ceil mode is supported via op padding, which must be statically known.
290293
is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
291-
if is_ceil_mode:
292-
why(node, reason="ceil mode is not supported")
294+
is_dynamic = "val" in node.meta and any(
295+
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
296+
)
297+
if is_ceil_mode and is_dynamic:
298+
why(node, reason="ceil mode is not supported for dynamic shapes")
293299
return False
294300
return True
295301

backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
9+
import itertools
710
import unittest
811

912
import torch
10-
from executorch.backends.xnnpack.test.tester import Tester
13+
from executorch.backends.xnnpack.test.tester import Export, Tester
14+
from torch.export.dynamic_shapes import Dim
1115

1216

1317
class TestMaxPool2d(unittest.TestCase):
@@ -38,10 +42,12 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
3842
def forward(self, x):
3943
return self.max_pool2d_module(x)[1]
4044

41-
class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
42-
def __init__(self):
45+
class MaxPool2dCeilMode(torch.nn.Module):
46+
def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
4347
super().__init__()
44-
self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
48+
self.max_pool2d_module = torch.nn.MaxPool2d(
49+
kernel_size, stride, padding, dilation, ceil_mode=True
50+
)
4551

4652
def forward(self, x):
4753
return self.max_pool2d_module(x)
@@ -93,14 +99,56 @@ def test_fp32_maxpool2d_unsupported(self):
9399
)
94100
)
95101

96-
def test_fp32_maxpool2d_unsupported_ceilmode(self):
102+
def test_fp32_maxpool2d_ceilmode(self):
103+
input_sizes = [[17, 32], [32, 37]]
104+
kernel_sizes = [2, 3, 12]
105+
strides = [1, 2, 4]
106+
padding = [0, 1, 5]
107+
dilations = [1, 2, 3]
108+
109+
for input_size, kernel_size, stride, pad, dilation in itertools.product(
110+
input_sizes, kernel_sizes, strides, padding, dilations
111+
):
112+
# Check XNNPACK and PyTorch constraints
113+
if pad > ((kernel_size - 1) * dilation + 1) / 2:
114+
continue
115+
if stride > kernel_size:
116+
continue
117+
if any(
118+
(size + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1 <= 0
119+
for size in input_size
120+
): # Output size too small
121+
continue
122+
123+
inputs = (torch.randn(1, 1, input_size[0], input_size[1]),)
124+
(
125+
Tester(
126+
self.MaxPool2dCeilMode(kernel_size, stride, pad, dilation), inputs
127+
)
128+
.export()
129+
.check_count({"torch.ops.aten.max_pool2d.default": 1})
130+
.to_edge_transform_and_lower()
131+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
132+
.check_not(
133+
[
134+
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
135+
]
136+
)
137+
.to_executorch()
138+
.serialize()
139+
.run_method_and_compare_outputs()
140+
)
141+
142+
def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
97143
"""
98-
MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
144+
MaxPool2d with ceil mode is supported with dynamic shape (see maxpool2d constraint).
99145
"""
100146
inputs = (torch.randn(1, 32, 23, 23),)
147+
dim3 = Dim("_dim3", min=11, max=50)
148+
dynamic_shapes = {"x": {3: 2 * dim3 - 1}}
101149
(
102-
Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
103-
.export()
150+
Tester(self.MaxPool2dCeilMode(), inputs)
151+
.export(Export(dynamic_shapes=dynamic_shapes))
104152
.check_count({"torch.ops.aten.max_pool2d.default": 1})
105153
.to_edge_transform_and_lower()
106154
# We expect it not be be delegated.

0 commit comments

Comments
 (0)