add flag for dynamic shapes to filter out static ops #3733

Closed
wants to merge 1 commit into from
18 changes: 8 additions & 10 deletions backends/xnnpack/partition/configs.py
@@ -144,14 +144,12 @@

SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES

# TODO delete this once we catch up to 100% of the supported ops with dynamic shape support.
# This is to be used only during the transition when we may not want to partition all the
# nodes for a dynamic model.
_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.mul.Tensor,
]
_SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE = [
    torch.nn.Conv1d,
    torch.nn.Conv2d,
# XNNPACK supports the majority of shape dynamism; however, some ops are
# explicitly static, so we maintain a set here to exclude them from
# dynamic shape support.
STATIC_OPS = [
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.slice_copy.Tensor,
]

STATIC_MODULES = []
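The net effect of this config change is a switch from an allow-list (only add.Tensor and mul.Tensor were partitioned for dynamic models) to a deny-list (everything supported is partitioned except the static-only ops). A minimal sketch of the new filtering rule — the helper name below is chosen for illustration only, and it assumes the existing SUPPORTED_OPS list in the same configs module:

from executorch.backends.xnnpack.partition.configs import STATIC_OPS, SUPPORTED_OPS

def ops_for_dynamic_shapes(supported_ops=SUPPORTED_OPS):
    # Deny-list semantics: start from every supported op and drop the ones
    # XNNPACK can only handle with static shapes (cat, slice_copy).
    return set(supported_ops) - set(STATIC_OPS)

The partitioner change below applies exactly this subtraction to both the FP32 and quantized op/module lists.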
46 changes: 11 additions & 35 deletions backends/xnnpack/partition/xnnpack_partitioner.py
@@ -12,8 +12,8 @@
import torch

from executorch.backends.xnnpack.partition.configs import (
    _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE,
    _SUPPORTED_OPS_WITH_DYNAMIC_SHAPE,
    STATIC_MODULES,
    STATIC_OPS,
    SUPPORTED_DYN_QUANT_LINEAR_MODULES,
    SUPPORTED_DYN_QUANT_MODULES,
    SUPPORTED_MODULES,
@@ -838,7 +838,7 @@ def __init__(
        supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
        supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
        quant: Optional[bool] = None,
        _only_ops_with_dynamic_shape_support: Optional[bool] = False,
        has_dynamic_shapes: bool = False,
        _lower_recomposed_sdpa: Optional[bool] = True,
    ):
        super().__init__()
@@ -851,44 +851,20 @@

        self.quant = quant

        if _only_ops_with_dynamic_shape_support is True:
            self._update_op_lists_for_dynamic_shapes()

        # TODO(T174256335) - remove this once we have a better way to handle >2d Mask
        self._lower_recomposed_sdpa: bool = _lower_recomposed_sdpa or True

        self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
        self.partition_tags: Dict[str, DelegationSpec] = {}

    def _update_op_lists_for_dynamic_shapes(self):
        # Not ready for quants yet
        assert (
            self.quant is not True
        ), "Dynamic shape only supported for valid FP32 ops, no quants support yet."
        self.supported_quant_ops = set()
        self.supported_quant_modules = set()

        # for supported ops
        self.supported_ops_with_dynamic_shape = set(_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE)
        assert self.supported_ops_with_dynamic_shape.issubset(
            self.supported_ops
        ), "All ops with dynamic shape support must be in SUPPORTED_OPS"
        self.supported_ops = self.supported_ops_with_dynamic_shape
        log.info(
            f"Xnnpack Partitioner updated supported op for dynamic shapes: {self.supported_ops}"
        )

        # for supported modules
        self.supported_modules_with_dynamic_shape = set(
            _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE
        )
        assert self.supported_modules_with_dynamic_shape.issubset(
            self.supported_modules
        ), "All modules with dynamic shape support must be in SUPPORTED_MODULES"
        self.supported_modules = self.supported_modules_with_dynamic_shape
        log.info(
            f"Xnnpack Partitioner updated supported modules with dynamic shapes: {self.supported_modules}"
        )
        self.has_dynamic_shapes = has_dynamic_shapes
        if has_dynamic_shapes:
            self.supported_ops = self.supported_ops - set(STATIC_OPS)
            self.supported_modules = self.supported_modules - set(STATIC_MODULES)
            self.supported_quant_ops = self.supported_quant_ops - set(STATIC_OPS)
            self.supported_quant_modules = self.supported_quant_modules - set(
                STATIC_MODULES
            )

    def get_supported_modules(self, quant: bool) -> Set[Callable]:
        """
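Taken together, the partitioner change means callers opt in with the new flag when their exported program uses dynamic shapes. Below is a minimal usage sketch (not part of this PR), assuming the usual torch.export → to_edge → to_backend flow; the toy module, dim bounds, and variable names are illustrative only:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return x * (x + y)

# Export with a dynamic batch dimension (bounds chosen for illustration).
batch = torch.export.Dim("batch", min=2, max=8)
exported = torch.export.export(
    AddMul(),
    (torch.randn(2, 4), torch.randn(2, 4)),
    dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
)

edge = to_edge(exported)
# With has_dynamic_shapes=True, static-only ops (cat, slice_copy) are left
# un-partitioned; everything else is still delegated to XNNPACK.
edge = edge.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))

The test added below exercises the negative case: with the flag set, a static slice is expected not to be delegated.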
23 changes: 22 additions & 1 deletion backends/xnnpack/test/ops/slice_copy.py
@@ -7,7 +7,8 @@
import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.test.tester import Partition, Tester


class TestSliceCopy(unittest.TestCase):
@@ -112,6 +113,26 @@ def forward(self, x):
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        )

    def test_fp32_static_slice_with_dynamic_dim(self):
        """
        XNNPACK does not support dynamic dims with static slice
        """

        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5),)
        (
            Tester(SliceCopy(), inputs)
            .export()
            .to_edge()
            .partition(
                Partition(partitioner=XnnpackPartitioner(has_dynamic_shapes=True))
            )
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        )

    # Note: Slice ends up as slice_copy later in the process, but during quantization,
    # it's still slice, which isn't supported by the XNNPACK quantizer.
    @unittest.skip("T156004676 - slice isn't propagated")