Support ceil_mode=True on maxpool2d in XNNPACK delegate #7355

Merged · 1 commit · Dec 21, 2024
50 changes: 47 additions & 3 deletions backends/xnnpack/operators/op_max_pool2d.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Dict, List

import torch
@@ -51,9 +53,9 @@ def define_node(
kwargs["output_id"] = vals_to_ids[node]

# kernel info
kernal_shape = cast(List[int], node.args[1])
kwargs["pooling_height"] = kernal_shape[0]
kwargs["pooling_width"] = kernal_shape[1]
kernel_shape = cast(List[int], node.args[1])
kwargs["pooling_height"] = kernel_shape[0]
kwargs["pooling_width"] = kernel_shape[1]

# stride info
stride = cast(List[int], node.args[2])
@@ -81,6 +83,26 @@
kwargs["dilation_height"] = dilation[0]
kwargs["dilation_width"] = dilation[1]

# ceil mode
ceil_mode = node.args[5] if len(node.args) > 5 else False
if ceil_mode:
# use original input shape as xnnpack input may be permuted
orig_input_shape = node.all_input_nodes[0].meta["val"].shape
kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
orig_input_shape[2],
kernel_shape[0],
stride[0],
padding_shape[0],
dilation[0],
)
kwargs["padding_right"] += self.calculate_pad_amount_1d(
orig_input_shape[3],
kernel_shape[1],
stride[1],
padding_shape[1],
dilation[1],
)

kwargs["flags"] = XNN_FLAG_KEEP_DIMS

ser_node = XNode(
@@ -90,3 +112,25 @@
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)

def calculate_pad_amount_1d(self, in_size, kernel_size, stride, padding, dilation):
# Determine the number of padding elements to add along a single dimension
# to match the ceil_mode=True behavior.
# See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html

# Determine the number of input elements to exactly bump up the output size
# by 1. Note that there is an additional condition to subtract 1 from the
# output when ceil_mode=True and (output_size - 1) * stride >= in_size + padding.
# In this case, we don't need to pad, as ceil_mode=False and True give the
# same behavior.
numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
numerator = numerator_no_ceil + stride - 1
output_size = numerator // stride + 1

needs_adjust = (output_size - 1) * stride >= in_size + padding
partial_stride = numerator_no_ceil % stride
pad_out = (
(stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0
)

return pad_out
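
The padding equivalence these hunks rely on can be sanity-checked in eager PyTorch. Below is a minimal sketch (not part of the PR) assuming XNNPACK's maxpool padding behaves like -inf padding, i.e. padded elements never win the max; the pad amounts are hand-computed from the formula above for one concrete shape.

```python
# Minimal sketch: ceil_mode=True pooling should equal ceil_mode=False pooling
# after padding the bottom/right edges with -inf by the computed amounts.
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 17, 32)
ceil_out = F.max_pool2d(x, kernel_size=3, stride=2, ceil_mode=True)

# Height 17: numerator_no_ceil = 17 - 2 - 1 = 14; 14 % 2 == 0 -> no extra pad.
# Width 32:  numerator_no_ceil = 32 - 2 - 1 = 29; 29 % 2 == 1 -> pad right by 2 - 1 = 1.
padded = F.pad(x, (0, 1, 0, 0), value=float("-inf"))  # (left, right, top, bottom)
floor_out = F.max_pool2d(padded, kernel_size=3, stride=2, ceil_mode=False)
assert torch.equal(ceil_out, floor_out)
```
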
10 changes: 8 additions & 2 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import cast, List, Optional

@@ -287,9 +289,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

# Ceil mode is supported via op padding, which must be statically known.
is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
if is_ceil_mode:
why(node, reason="ceil mode is not supported")
is_dynamic = "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)
if is_ceil_mode and is_dynamic:
why(node, reason="ceil mode is not supported for dynamic shapes")
return False
return True

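For reference, a hypothetical snippet (not from the PR) showing how a dynamic dimension surfaces as torch.SymInt in node.meta["val"].shape, which is exactly what the constraint above inspects; the dim name "w" and bounds are arbitrary.

```python
# Exporting with a dynamic width makes that shape entry a torch.SymInt.
import torch
from torch.export import Dim, export

pool = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
x = torch.randn(1, 32, 23, 23)
ep = export(pool, (x,), dynamic_shapes=({3: Dim("w", min=11, max=50)},))

for node in ep.graph.nodes:
    val = node.meta.get("val")
    if node.op == "call_function" and isinstance(val, torch.Tensor):
        # e.g. aten.max_pool2d.default ['int', 'int', 'int', 'SymInt']
        print(node.target, [type(d).__name__ for d in val.shape])
```
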
64 changes: 56 additions & 8 deletions backends/xnnpack/test/ops/test_maxpool2d.py
@@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester import Export, Tester
from torch.export.dynamic_shapes import Dim


class TestMaxPool2d(unittest.TestCase):
@@ -38,10 +42,12 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
def forward(self, x):
return self.max_pool2d_module(x)[1]

class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
def __init__(self):
class MaxPool2dCeilMode(torch.nn.Module):
def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
super().__init__()
self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.max_pool2d_module = torch.nn.MaxPool2d(
kernel_size, stride, padding, dilation, ceil_mode=True
)

def forward(self, x):
return self.max_pool2d_module(x)
@@ -93,14 +99,56 @@ def test_fp32_maxpool2d_unsupported(self):
)
)

def test_fp32_maxpool2d_unsupported_ceilmode(self):
def test_fp32_maxpool2d_ceilmode(self):
input_sizes = [[17, 32], [32, 37]]
kernel_sizes = [2, 3, 12]
strides = [1, 2, 4]
padding = [0, 1, 5]
dilations = [1, 2, 3]

for input_size, kernel_size, stride, pad, dilation in itertools.product(
input_sizes, kernel_sizes, strides, padding, dilations
):
# Check XNNPACK and PyTorch constraints
if pad > ((kernel_size - 1) * dilation + 1) / 2:
continue
if stride > kernel_size:
continue
if any(
(size + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1 <= 0
for size in input_size
): # Output size too small
continue

inputs = (torch.randn(1, 1, input_size[0], input_size[1]),)
(
Tester(
self.MaxPool2dCeilMode(kernel_size, stride, pad, dilation), inputs
)
.export()
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(
[
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
]
)
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
"""
MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
MaxPool2d with ceil mode is not supported with dynamic shapes (see maxpool2d constraint).
"""
inputs = (torch.randn(1, 32, 23, 23),)
dim3 = Dim("_dim3", min=11, max=50)
dynamic_shapes = {"x": {3: 2 * dim3 - 1}}
(
Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
.export()
Tester(self.MaxPool2dCeilMode(), inputs)
.export(Export(dynamic_shapes=dynamic_shapes))
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.to_edge_transform_and_lower()
# We expect it not to be delegated.
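
Finally, the size arithmetic in calculate_pad_amount_1d can be checked exhaustively in one dimension. A small standalone sketch (not in the PR); the parameter grid and constraint filters loosely mirror the ceil-mode test above.

```python
# Exhaustive 1-D check: floor-mode output size after the extra padding matches
# ceil-mode output size, including the dropped-window adjustment case.
import itertools

for in_size, k, s, p, d in itertools.product(
    range(4, 40), (2, 3, 5, 12), (1, 2, 3, 4), (0, 1, 2, 5), (1, 2, 3)
):
    if p > ((k - 1) * d + 1) / 2:  # PyTorch's padding constraint
        continue
    no_ceil = in_size + 2 * p - d * (k - 1) - 1
    if no_ceil < 0:  # output would be empty
        continue
    ceil_size = (no_ceil + s - 1) // s + 1
    if (ceil_size - 1) * s >= in_size + p:  # last window starts in right padding
        ceil_size -= 1
        pad_out = 0
    else:
        partial = no_ceil % s
        pad_out = s - partial if partial > 0 else 0
    assert (no_ceil + pad_out) // s + 1 == ceil_size, (in_size, k, s, p, d)
print("ok")
```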