
Commit 953640f

Add stride constraint to XNN MaxPool
Differential Revision: D67380978
Pull Request resolved: #7354
1 parent fc04436 commit 953640f

3 files changed: +51 -11 lines changed

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 16 additions & 11 deletions
@@ -19,7 +19,7 @@
 from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
-from executorch.exir.backend.utils import WhyNoPartition
+from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
 from torch.export import ExportedProgram
 
 logger = logging.getLogger(__name__)
@@ -284,19 +284,27 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):
 
     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         """
-        XNNPACK's maxpool2d does not support ceil mode
+        XNNPACK's maxpool2d does not support ceil mode and requires stride <= kernel_size
         """
         if not self.check_common_constraints(node, ep):
             return False
 
-        # Ceil mode is supported via op padding, which must be statically known.
+        kernel_size = node.args[1]
+        stride = node.args[2]
         is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        is_dynamic = "val" in node.meta and any(
-            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
-        )
-        if is_ceil_mode and is_dynamic:
+
+        # Ceil mode is supported via op padding, which must be statically known.
+        if is_ceil_mode and is_shape_dynamic(node):
             why(node, reason="ceil mode is not supported for dynamic shapes")
             return False
+
+        if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]:  # pyre-ignore[16]
+            why(
+                node,
+                reason=f"stride ({stride}) must be less than or equal to kernel size ({kernel_size})",
+            )
+            return False
+
         return True
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
@@ -316,10 +324,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
-        is_output_dynamic = "val" in node.meta and any(
-            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
-        )
-        if is_output_dynamic:
+        if is_shape_dynamic(node):
             why(node, reason="dynamic output sizes are not supported")
             return False
         return True
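
For context, an aten.max_pool2d node stores the kernel size at args[1] and the stride at args[2], each a two-element sequence, and the new constraint compares them element-wise. A minimal standalone sketch of the same comparison (the helper name and example values below are illustrative, not part of the commit):

from typing import Sequence


def stride_fits_kernel(kernel_size: Sequence[int], stride: Sequence[int]) -> bool:
    # Mirrors the partitioner constraint: XNNPACK maxpool2d only supports
    # stride <= kernel_size in both spatial dimensions.
    return stride[0] <= kernel_size[0] and stride[1] <= kernel_size[1]


print(stride_fits_kernel([3, 3], [2, 2]))  # True  -> eligible for delegation
print(stride_fits_kernel([2, 2], [3, 3]))  # False -> falls back to the portable op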

backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 22 additions & 0 deletions
@@ -163,6 +163,28 @@ def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
             .run_method_and_compare_outputs()
         )
 
+    def test_fp32_maxpool2d_unsupported_stride(self):
+        """
+        XNNPACK MaxPool2d requires stride <= kernel_size.
+        """
+        inputs = (torch.randn(1, 32, 23, 23),)
+        (
+            Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs)
+            .export()
+            .check_count({"torch.ops.aten.max_pool2d.default": 1})
+            .to_edge_transform_and_lower()
+            # We expect it not to be delegated.
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
     def test_qs8_maxpool2d(self):
         class MaxPool(torch.nn.Module):
             def __init__(self, maxpool_params):
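
The self.MaxPool2d module referenced by the new test is defined elsewhere in this test file. A plausible minimal stand-in (a sketch under that assumption, not the repository's actual helper) shows the configuration being exercised; eager PyTorch runs stride > kernel_size without complaint, which is why the partitioner must reject it rather than rely on export failing:

import torch


class MaxPool2d(torch.nn.Module):
    # Thin wrapper so kernel_size/stride can be parameterized per test case.
    def __init__(self, kernel_size: int, stride: int):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(x)


# stride (3) > kernel_size (2): runs fine in eager mode, but the XNNPACK
# partitioner refuses to delegate it, so the edge dialect op remains.
out = MaxPool2d(kernel_size=2, stride=3)(torch.randn(1, 32, 23, 23))
print(out.shape)  # torch.Size([1, 32, 8, 8])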

exir/backend/utils.py

Lines changed: 13 additions & 0 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import logging
 import operator
 from collections import defaultdict
@@ -417,6 +419,17 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
                 node.meta["delegation_tag"] = user_tags.pop()
 
 
+def is_shape_dynamic(node: torch.fx.Node) -> bool:
+    """
+    Check if the node shape is dynamic.
+    """
+
+    # Shape is dynamic if any of the dimensions don't evaluate to a static value
+    return "val" in node.meta and any(
+        isinstance(d, torch.SymInt) for d in node.meta["val"].shape
+    )
+
+
 # TODO - style: use templated types
 class DelegateMappingBuilder:
     """

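As a rough illustration of when the new helper returns True: after torch.export with a dynamic dimension, graph nodes carry a fake tensor in node.meta["val"] whose shape mixes plain ints with torch.SymInt. A sketch (the module and dimension names are made up for the example; is_shape_dynamic is the helper added by this commit):

import torch
from torch.export import Dim, export

from executorch.exir.backend.utils import is_shape_dynamic


class Relu(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


# Mark the batch dimension as dynamic so it is traced as a SymInt.
batch = Dim("batch", min=1, max=8)
ep = export(Relu(), (torch.randn(2, 4),), dynamic_shapes={"x": {0: batch}})

for node in ep.graph.nodes:
    if "val" in node.meta and hasattr(node.meta["val"], "shape"):
        # Nodes whose shape involves the dynamic batch dim report True.
        print(node.name, tuple(node.meta["val"].shape), is_shape_dynamic(node))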