Commit 5ca5282

mcr229 authored and facebook-github-bot committed
add flag for dynamic shapes to filter out static ops (#3733)
Summary: Since we have updated XNNPACK to support almost 100% of dynamic-shape ops, we can now maintain static-op lists (ops with no dynamic shape support) and filter those out instead.

Differential Revision: D57787384
1 parent 1343224 commit 5ca5282
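
A minimal usage sketch of the new flag. Only the XnnpackPartitioner(dynamic_shape=True) argument comes from this commit; the toy model, the torch.export call, and the to_edge/to_backend lowering around it are assumed standard ExecuTorch usage shown for illustration, not part of this diff.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge
from torch.export import Dim, export


class AddMul(torch.nn.Module):  # hypothetical toy model for illustration
    def forward(self, x, y):
        return x + y * y


# Export with a dynamic batch dimension.
batch = Dim("batch", min=1, max=16)
ep = export(
    AddMul(),
    (torch.randn(2, 8), torch.randn(2, 8)),
    dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
)

# With dynamic_shape=True the partitioner subtracts STATIC_OPS / STATIC_MODULES
# from its supported sets, so ops without dynamic-shape support in XNNPACK
# (e.g. cat, slice_copy) are not delegated.
edge = to_edge(ep).to_backend(XnnpackPartitioner(dynamic_shape=True))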

File tree

2 files changed: +17 −45 lines changed


backends/xnnpack/partition/configs.py

Lines changed: 8 additions & 10 deletions
@@ -144,14 +144,12 @@
 
 SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES
 
-# TODO delete this once we catch up to 100% of the supported op with dynamic shape support.
-# This is tobe used only during the transition when we may not want to partition all the
-# nodes for a dynamic model.
-_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE = [
-    exir_ops.edge.aten.add.Tensor,
-    exir_ops.edge.aten.mul.Tensor,
-]
-_SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE = [
-    torch.nn.Conv1d,
-    torch.nn.Conv2d,
+# XNNPACK supports majority of shape dynamism, however some ops are
+# explicitly static, so we maintain a set here to exclude them from
+# dynamic shape support.
+STATIC_OPS = [
+    exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.slice_copy.Tensor,
 ]
+
+STATIC_MODULES = []

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 9 additions & 35 deletions
@@ -12,8 +12,8 @@
 import torch
 
 from executorch.backends.xnnpack.partition.configs import (
-    _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE,
-    _SUPPORTED_OPS_WITH_DYNAMIC_SHAPE,
+    STATIC_MODULES,
+    STATIC_OPS,
     SUPPORTED_DYN_QUANT_LINEAR_MODULES,
     SUPPORTED_DYN_QUANT_MODULES,
     SUPPORTED_MODULES,
@@ -94,6 +94,7 @@ def __init__(
         ] = _OP_SUPPORT_CONSTRAINTS,
         supported_ops: Optional[List] = None,
         unsupported_modules: Optional[List] = None,
+        dynamic_shapes=False,
     ):
         """
         @Arg constraints_dict: Dict mapping each node to a lambda function that
@@ -111,6 +112,7 @@ def __init__(
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.bmm.default,
         }
+        self.dynamic_shapes = dynamic_shapes
         assert len(self.constraints)
 
     def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
@@ -838,7 +840,7 @@ def __init__(
         supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
         supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
         quant: Optional[bool] = None,
-        _only_ops_with_dynamic_shape_support: Optional[bool] = False,
+        dynamic_shape: bool = False,
         _lower_recomposed_sdpa: Optional[bool] = True,
     ):
         super().__init__()
@@ -851,44 +853,16 @@ def __init__(
 
         self.quant = quant
 
-        if _only_ops_with_dynamic_shape_support is True:
-            self._update_op_lists_for_dynamic_shapes()
-
         # TODO(T174256335) - remove this once we have a better way to handle >2d Mask
         self._lower_recomposed_sdpa: bool = _lower_recomposed_sdpa or True
 
         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
         self.partition_tags: Dict[str, DelegationSpec] = {}
 
-    def _update_op_lists_for_dynamic_shapes(self):
-        # Not ready for quants yet
-        assert (
-            self.quant is not True
-        ), "Dynamic shape only supported for valid FP32 ops, no quants support yet."
-        self.supported_quant_ops = set()
-        self.supported_quant_modules = set()
-
-        # for supported ops
-        self.supported_ops_with_dynamic_shape = set(_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE)
-        assert self.supported_ops_with_dynamic_shape.issubset(
-            self.supported_ops
-        ), "All ops with dynamic shape support must be in SUPPORTED_OPS"
-        self.supported_ops = self.supported_ops_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported op for dynamic shapes: {self.supported_ops}"
-        )
-
-        # for supported modules
-        self.supported_modules_with_dynamic_shape = set(
-            _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE
-        )
-        assert self.supported_modules_with_dynamic_shape.issubset(
-            self.supported_modules
-        ), "All modules with dynamic shape support must be in SUPPORTED_MODULES"
-        self.supported_modules = self.supported_modules_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported modules with dynamic shapes: {self.supported_modules}"
-        )
+        self.dynamic_shape = dynamic_shape
+        if dynamic_shape:
+            self.supported_ops = self.supported_ops - set(STATIC_OPS)
+            self.supported_modules = self.supported_modules - set(STATIC_MODULES)
 
     def get_supported_modules(self, quant: bool) -> Set[Callable]:
         """
