@@ -12,8 +12,8 @@
 import torch
 
 from executorch.backends.xnnpack.partition.configs import (
-    _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE,
-    _SUPPORTED_OPS_WITH_DYNAMIC_SHAPE,
+    STATIC_MODULES,
+    STATIC_OPS,
     SUPPORTED_DYN_QUANT_LINEAR_MODULES,
     SUPPORTED_DYN_QUANT_MODULES,
     SUPPORTED_MODULES,
@@ -94,6 +94,7 @@ def __init__(
         ] = _OP_SUPPORT_CONSTRAINTS,
         supported_ops: Optional[List] = None,
         unsupported_modules: Optional[List] = None,
+        dynamic_shapes=False,
     ):
         """
         @Arg constraints_dict: Dict mapping each node to a lambda function that
@@ -111,6 +112,7 @@ def __init__(
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.bmm.default,
         }
+        self.dynamic_shapes = dynamic_shapes
         assert len(self.constraints)
 
     def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
@@ -838,7 +840,7 @@ def __init__(
         supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
         supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
         quant: Optional[bool] = None,
-        _only_ops_with_dynamic_shape_support: Optional[bool] = False,
+        dynamic_shape: bool = False,
         _lower_recomposed_sdpa: Optional[bool] = True,
     ):
         super().__init__()
@@ -851,44 +853,16 @@ def __init__(
 
         self.quant = quant
 
-        if _only_ops_with_dynamic_shape_support is True:
-            self._update_op_lists_for_dynamic_shapes()
-
         # TODO(T174256335) - remove this once we have a better way to handle >2d Mask
         self._lower_recomposed_sdpa: bool = _lower_recomposed_sdpa or True
 
         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
         self.partition_tags: Dict[str, DelegationSpec] = {}
 
-    def _update_op_lists_for_dynamic_shapes(self):
-        # Not ready for quants yet
-        assert (
-            self.quant is not True
-        ), "Dynamic shape only supported for valid FP32 ops, no quants support yet."
-        self.supported_quant_ops = set()
-        self.supported_quant_modules = set()
-
-        # for supported ops
-        self.supported_ops_with_dynamic_shape = set(_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE)
-        assert self.supported_ops_with_dynamic_shape.issubset(
-            self.supported_ops
-        ), "All ops with dynamic shape support must be in SUPPORTED_OPS"
-        self.supported_ops = self.supported_ops_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported op for dynamic shapes: {self.supported_ops}"
-        )
-
-        # for supported modules
-        self.supported_modules_with_dynamic_shape = set(
-            _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE
-        )
-        assert self.supported_modules_with_dynamic_shape.issubset(
-            self.supported_modules
-        ), "All modules with dynamic shape support must be in SUPPORTED_MODULES"
-        self.supported_modules = self.supported_modules_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported modules with dynamic shapes: {self.supported_modules}"
-        )
+        self.dynamic_shape = dynamic_shape
+        if dynamic_shape:
+            self.supported_ops = self.supported_ops - set(STATIC_OPS)
+            self.supported_modules = self.supported_modules - set(STATIC_MODULES)
 
     def get_supported_modules(self, quant: bool) -> Set[Callable]:
         """