Commit 283b31b

mcr229 authored and facebook-github-bot committed
add flag for dynamic shapes to filter out static ops (#3733)
Summary: Since we have updated XNNPACK to support almost 100% of ops with dynamic shapes, we can now maintain static-op lists of the ops that do not have dynamic shape support and filter those out instead.

Reviewed By: digantdesai

Differential Revision: D57787384
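For context, here is a minimal usage sketch of the new flag (the toy module, dim names, and the torch.export / to_edge calls around it are illustrative assumptions, not part of this commit):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

# Toy model with a dynamic batch dimension (illustrative only).
class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return (x + y) * y

batch = torch.export.Dim("batch", min=1, max=16)
ep = torch.export.export(
    AddMul(),
    (torch.randn(2, 8), torch.randn(2, 8)),
    dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
)

# has_dynamic_shapes=True drops STATIC_OPS / STATIC_MODULES (e.g. cat, slice_copy)
# from the partitioner's supported sets before nodes are tagged for delegation.
edge = to_edge(ep).to_backend(XnnpackPartitioner(has_dynamic_shapes=True))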
1 parent f42942a commit 283b31b

File tree: 2 files changed, +15 -45 lines

backends/xnnpack/partition/configs.py

Lines changed: 8 additions & 10 deletions

@@ -144,14 +144,12 @@
 
 SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES
 
-# TODO delete this once we catch up to 100% of the supported op with dynamic shape support.
-# This is tobe used only during the transition when we may not want to partition all the
-# nodes for a dynamic model.
-_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE = [
-    exir_ops.edge.aten.add.Tensor,
-    exir_ops.edge.aten.mul.Tensor,
-]
-_SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE = [
-    torch.nn.Conv1d,
-    torch.nn.Conv2d,
+# XNNPACK supports majority of shape dynamism, however some ops are
+# explicitly static, so we maintain a set here to exclude them from
+# dynamic shape support.
+STATIC_OPS = [
+    exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.slice_copy.Tensor,
 ]
+
+STATIC_MODULES = []
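The new lists are consumed by set difference in the partitioner below. A quick way to inspect what remains supported for dynamic-shape models (a sketch; SUPPORTED_OPS is assumed to be defined alongside the lists above, as the partitioner's imports suggest):

from executorch.backends.xnnpack.partition.configs import (
    STATIC_OPS,
    SUPPORTED_OPS,  # assumed to live in the same configs module
)

# Ops the partitioner will still consider when has_dynamic_shapes=True.
dynamic_shape_ops = set(SUPPORTED_OPS) - set(STATIC_OPS)
print(f"{len(dynamic_shape_ops)} ops remain supported with dynamic shapes")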

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 7 additions & 35 deletions

@@ -12,8 +12,8 @@
 import torch
 
 from executorch.backends.xnnpack.partition.configs import (
-    _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE,
-    _SUPPORTED_OPS_WITH_DYNAMIC_SHAPE,
+    STATIC_MODULES,
+    STATIC_OPS,
     SUPPORTED_DYN_QUANT_LINEAR_MODULES,
     SUPPORTED_DYN_QUANT_MODULES,
     SUPPORTED_MODULES,

@@ -838,7 +838,7 @@ def __init__(
         supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
         supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
         quant: Optional[bool] = None,
-        _only_ops_with_dynamic_shape_support: Optional[bool] = False,
+        has_dynamic_shapes: bool = False,
         _lower_recomposed_sdpa: Optional[bool] = True,
     ):
         super().__init__()

@@ -851,44 +851,16 @@ def __init__(
 
         self.quant = quant
 
-        if _only_ops_with_dynamic_shape_support is True:
-            self._update_op_lists_for_dynamic_shapes()
-
         # TODO(T174256335) - remove this once we have a better way to handle >2d Mask
         self._lower_recomposed_sdpa: bool = _lower_recomposed_sdpa or True
 
         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
         self.partition_tags: Dict[str, DelegationSpec] = {}
 
-    def _update_op_lists_for_dynamic_shapes(self):
-        # Not ready for quants yet
-        assert (
-            self.quant is not True
-        ), "Dynamic shape only supported for valid FP32 ops, no quants support yet."
-        self.supported_quant_ops = set()
-        self.supported_quant_modules = set()
-
-        # for supported ops
-        self.supported_ops_with_dynamic_shape = set(_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE)
-        assert self.supported_ops_with_dynamic_shape.issubset(
-            self.supported_ops
-        ), "All ops with dynamic shape support must be in SUPPORTED_OPS"
-        self.supported_ops = self.supported_ops_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported op for dynamic shapes: {self.supported_ops}"
-        )
-
-        # for supported modules
-        self.supported_modules_with_dynamic_shape = set(
-            _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE
-        )
-        assert self.supported_modules_with_dynamic_shape.issubset(
-            self.supported_modules
-        ), "All modules with dynamic shape support must be in SUPPORTED_MODULES"
-        self.supported_modules = self.supported_modules_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported modules with dynamic shapes: {self.supported_modules}"
-        )
+        self.has_dynamic_shapes = has_dynamic_shapes
+        if has_dynamic_shapes:
+            self.supported_ops = self.supported_ops - set(STATIC_OPS)
+            self.supported_modules = self.supported_modules - set(STATIC_MODULES)
 
     def get_supported_modules(self, quant: bool) -> Set[Callable]:
         """

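To summarize the behavioral shift in the last hunk, an allowlist became a denylist. A plain-Python illustration of the difference (op names are stand-ins, not real edge ops):

# Illustrative only: the allowlist -> denylist switch in plain Python sets.
supported_ops = {"add", "mul", "cat", "slice_copy", "conv2d", "softmax"}

# Before: dynamic-shape models kept only a small allowlist (add, mul).
old_allowlist = {"add", "mul"}
old_result = supported_ops & old_allowlist        # {'add', 'mul'}

# After: dynamic-shape models drop only the static-only denylist.
static_ops = {"cat", "slice_copy"}
new_result = supported_ops - static_ops           # everything except cat/slice_copy

print(sorted(old_result))
print(sorted(new_result))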