
Commit 82763a9

Support ceil_mode=True on maxpool2d in XNNPACK delegate
Differential Revision: D67386151
Pull Request resolved: #7355
1 parent 6c3a792 commit 82763a9

File tree

3 files changed (+111, -13 lines)

backends/xnnpack/operators/op_max_pool2d.py
backends/xnnpack/partition/config/generic_node_configs.py
backends/xnnpack/test/ops/test_maxpool2d.py
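For context: ceil_mode=True rounds the pooling output size up instead of down, so a trailing partial window still produces an output element. A minimal sketch (not part of this commit) of the difference:

import torch

x = torch.randn(1, 1, 7, 7)
floor_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)
ceil_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
print(floor_pool(x).shape)  # torch.Size([1, 1, 3, 3])
print(ceil_pool(x).shape)   # torch.Size([1, 1, 4, 4]); the partial window is kept

The commit emulates this in the XNNPACK delegate by enlarging the bottom/right padding so that ordinary floor-mode arithmetic produces the ceil-mode output size.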

backends/xnnpack/operators/op_max_pool2d.py

Lines changed: 47 additions & 3 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from typing import cast, Dict, List
 
 import torch
@@ -51,9 +53,9 @@ def define_node(
         kwargs["output_id"] = vals_to_ids[node]
 
         # kernel info
-        kernal_shape = cast(List[int], node.args[1])
-        kwargs["pooling_height"] = kernal_shape[0]
-        kwargs["pooling_width"] = kernal_shape[1]
+        kernel_shape = cast(List[int], node.args[1])
+        kwargs["pooling_height"] = kernel_shape[0]
+        kwargs["pooling_width"] = kernel_shape[1]
 
         # stride info
         stride = cast(List[int], node.args[2])
@@ -81,6 +83,26 @@ def define_node(
         kwargs["dilation_height"] = dilation[0]
         kwargs["dilation_width"] = dilation[1]
 
+        # ceil mode
+        ceil_mode = node.args[5] if len(node.args) > 5 else False
+        if ceil_mode:
+            # use original input shape as xnnpack input may be permuted
+            orig_input_shape = node.all_input_nodes[0].meta["val"].shape
+            kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
+                orig_input_shape[2],
+                kernel_shape[0],
+                stride[0],
+                padding_shape[0],
+                dilation[0],
+            )
+            kwargs["padding_right"] += self.calculate_pad_amount_1d(
+                orig_input_shape[3],
+                kernel_shape[1],
+                stride[1],
+                padding_shape[1],
+                dilation[1],
+            )
+
         kwargs["flags"] = XNN_FLAG_KEEP_DIMS
 
         ser_node = XNode(
@@ -90,3 +112,25 @@ def define_node(
             debug_handle=debug_handle,
         )
         xnn_graph.xnodes.append(ser_node)
+
+    def calculate_pad_amount_1d(self, in_size, kernel_size, stride, padding, dilation):
+        # Determine the number of padding elements to add along a single dimension
+        # to match the ceil_mode=True behavior.
+        # See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
+
+        # Determine the number of input elements to exactly bump up the output size
+        # by 1. Note that there is an additional condition to subtract 1 from the
+        # output when ceil_mode=True and (output_size - 1) * stride >= in_size + padding.
+        # In this case, we don't need to pad, as ceil_mode=False and True give the
+        # same behavior.
+        numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
+        numerator = numerator_no_ceil + stride - 1
+        output_size = numerator // stride + 1
+
+        needs_adjust = (output_size - 1) * stride >= in_size + padding
+        partial_stride = numerator_no_ceil % stride
+        pad_out = (
+            (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0
+        )
+
+        return pad_out
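The padding math above can be cross-checked against PyTorch directly. A standalone sketch (not part of the commit; pad_amount_1d just mirrors calculate_pad_amount_1d): padding the trailing edge with -inf, so the extra elements can never win a max, makes floor-mode pooling match ceil_mode=True in both shape and values.

import torch
import torch.nn.functional as F

def pad_amount_1d(in_size, kernel_size, stride, padding, dilation):
    # Same arithmetic as calculate_pad_amount_1d above.
    numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
    output_size = (numerator_no_ceil + stride - 1) // stride + 1
    needs_adjust = (output_size - 1) * stride >= in_size + padding
    partial_stride = numerator_no_ceil % stride
    return (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0

for in_size, k, s, p, d in [(23, 2, 2, 0, 1), (17, 3, 2, 1, 1), (5, 2, 2, 1, 1)]:
    x = torch.randn(1, 1, in_size)
    extra = pad_amount_1d(in_size, k, s, p, d)
    ceil_out = F.max_pool1d(x, k, s, p, d, ceil_mode=True)
    # Pad only the trailing edge, mirroring the padding_bottom/padding_right bump.
    x_padded = F.pad(x, (0, extra), value=float("-inf"))
    floor_out = F.max_pool1d(x_padded, k, s, p, d, ceil_mode=False)
    assert torch.equal(ceil_out, floor_out)

The (5, 2, 2, 1, 1) case exercises the needs_adjust branch, where ceil and floor modes already agree and no extra padding is added.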

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 8 additions & 2 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import logging
 from typing import cast, List, Optional
 
@@ -287,9 +289,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
+        # Ceil mode is supported via op padding, which must be statically known.
         is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        if is_ceil_mode:
-            why(node, reason="ceil mode is not supported")
+        is_dynamic = "val" in node.meta and any(
+            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
+        )
+        if is_ceil_mode and is_dynamic:
+            why(node, reason="ceil mode is not supported for dynamic shapes")
             return False
         return True
 
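The constraint above treats a node as dynamic when any output dimension is symbolic. A minimal sketch (has_dynamic_dim is an illustrative name) of the same test in isolation:

import torch

def has_dynamic_dim(node: torch.fx.Node) -> bool:
    # Mirrors is_dynamic above: the FakeTensor stored in node.meta["val"]
    # carries a torch.SymInt for every dimension exported as dynamic.
    val = node.meta.get("val")
    return val is not None and any(isinstance(d, torch.SymInt) for d in val.shape)

Static shapes are required because the extra padding computed in op_max_pool2d.py is baked into the serialized XNNPACK graph at lowering time.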
backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 56 additions & 8 deletions
@@ -4,10 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
+import itertools
 import unittest
 
 import torch
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from torch.export.dynamic_shapes import Dim
 
 
 class TestMaxPool2d(unittest.TestCase):
@@ -38,10 +42,12 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
         def forward(self, x):
             return self.max_pool2d_module(x)[1]
 
-    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
-        def __init__(self):
+    class MaxPool2dCeilMode(torch.nn.Module):
+        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
             super().__init__()
-            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+            self.max_pool2d_module = torch.nn.MaxPool2d(
+                kernel_size, stride, padding, dilation, ceil_mode=True
+            )
 
         def forward(self, x):
             return self.max_pool2d_module(x)
@@ -93,14 +99,56 @@ def test_fp32_maxpool2d_unsupported(self):
             )
         )
 
-    def test_fp32_maxpool2d_unsupported_ceilmode(self):
+    def test_fp32_maxpool2d_ceilmode(self):
+        input_sizes = [[17, 32], [32, 37]]
+        kernel_sizes = [2, 3, 12]
+        strides = [1, 2, 4]
+        padding = [0, 1, 5]
+        dilations = [1, 2, 3]
+
+        for input_size, kernel_size, stride, pad, dilation in itertools.product(
+            input_sizes, kernel_sizes, strides, padding, dilations
+        ):
+            # Check XNNPACK and PyTorch constraints
+            if pad > ((kernel_size - 1) * dilation + 1) / 2:
+                continue
+            if stride > kernel_size:
+                continue
+            if any(
+                (size + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1 <= 0
+                for size in input_size
+            ):  # Output size too small
+                continue
+
+            inputs = (torch.randn(1, 1, input_size[0], input_size[1]),)
+            (
+                Tester(
+                    self.MaxPool2dCeilMode(kernel_size, stride, pad, dilation), inputs
+                )
+                .export()
+                .check_count({"torch.ops.aten.max_pool2d.default": 1})
+                .to_edge_transform_and_lower()
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                .check_not(
+                    [
+                        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
+                    ]
+                )
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
         """
-        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
+        MaxPool2d with ceil mode is not supported with dynamic shape (see maxpool2d constraint).
         """
         inputs = (torch.randn(1, 32, 23, 23),)
+        dim3 = Dim("_dim3", min=11, max=50)
+        dynamic_shapes = {"x": {3: 2 * dim3 - 1}}
         (
-            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
-            .export()
+            Tester(self.MaxPool2dCeilMode(), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
             .check_count({"torch.ops.aten.max_pool2d.default": 1})
             .to_edge_transform_and_lower()
             # We expect it not to be delegated.
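For reference, a minimal sketch (outside the Tester harness; Pool is an illustrative wrapper) of how the test above declares its dynamic dimension with torch.export:

import torch
from torch.export import export
from torch.export.dynamic_shapes import Dim

class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(3, stride=1, ceil_mode=True)

    def forward(self, x):
        return self.pool(x)

x = torch.randn(1, 32, 23, 23)
dim3 = Dim("_dim3", min=11, max=50)
# The derived expression 2 * dim3 - 1 constrains dim 3 to odd sizes in
# [21, 99]; the sample width 23 satisfies it with dim3 = 12.
ep = export(Pool(), (x,), dynamic_shapes={"x": {3: 2 * dim3 - 1}})

Because dim 3 is now a SymInt, the partitioner constraint from generic_node_configs.py keeps the ceil-mode max_pool2d out of the delegate, which is what the test asserts.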
