12
12
import torch
13
13
14
14
from executorch .backends .xnnpack .partition .configs import (
15
- _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE ,
16
- _SUPPORTED_OPS_WITH_DYNAMIC_SHAPE ,
15
+ STATIC_MODULES ,
16
+ STATIC_OPS ,
17
17
SUPPORTED_DYN_QUANT_LINEAR_MODULES ,
18
18
SUPPORTED_DYN_QUANT_MODULES ,
19
19
SUPPORTED_MODULES ,
@@ -838,7 +838,7 @@ def __init__(
838
838
supported_quant_modules : List [Callable ] = SUPPORTED_QUANT_MODULES ,
839
839
supported_quant_ops : Optional [List [Callable ]] = SUPPORTED_QUANT_OPS ,
840
840
quant : Optional [bool ] = None ,
841
- _only_ops_with_dynamic_shape_support : Optional [ bool ] = False ,
841
+ has_dynamic_shapes : bool = False ,
842
842
_lower_recomposed_sdpa : Optional [bool ] = True ,
843
843
):
844
844
super ().__init__ ()
@@ -851,44 +851,16 @@ def __init__(
851
851
852
852
self .quant = quant
853
853
854
- if _only_ops_with_dynamic_shape_support is True :
855
- self ._update_op_lists_for_dynamic_shapes ()
856
-
857
854
# TODO(T174256335) - remove this once we have a better way to handle >2d Mask
858
855
self ._lower_recomposed_sdpa : bool = _lower_recomposed_sdpa or True
859
856
860
857
self .delegation_spec = DelegationSpec (XnnpackBackend .__name__ , [])
861
858
self .partition_tags : Dict [str , DelegationSpec ] = {}
862
859
863
- def _update_op_lists_for_dynamic_shapes (self ):
864
- # Not ready for quants yet
865
- assert (
866
- self .quant is not True
867
- ), "Dynamic shape only supported for valid FP32 ops, no quants support yet."
868
- self .supported_quant_ops = set ()
869
- self .supported_quant_modules = set ()
870
-
871
- # for supported ops
872
- self .supported_ops_with_dynamic_shape = set (_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE )
873
- assert self .supported_ops_with_dynamic_shape .issubset (
874
- self .supported_ops
875
- ), "All ops with dynamic shape support must be in SUPPORTED_OPS"
876
- self .supported_ops = self .supported_ops_with_dynamic_shape
877
- log .info (
878
- f"Xnnpack Partitioner updated supported op for dynamic shapes: { self .supported_ops } "
879
- )
880
-
881
- # for supported modules
882
- self .supported_modules_with_dynamic_shape = set (
883
- _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE
884
- )
885
- assert self .supported_modules_with_dynamic_shape .issubset (
886
- self .supported_modules
887
- ), "All modules with dynamic shape support must be in SUPPORTED_MODULES"
888
- self .supported_modules = self .supported_modules_with_dynamic_shape
889
- log .info (
890
- f"Xnnpack Partitioner updated supported modules with dynamic shapes: { self .supported_modules } "
891
- )
860
+ self .has_dynamic_shapes = has_dynamic_shapes
861
+ if has_dynamic_shapes :
862
+ self .supported_ops = self .supported_ops - set (STATIC_OPS )
863
+ self .supported_modules = self .supported_modules - set (STATIC_MODULES )
892
864
893
865
def get_supported_modules (self , quant : bool ) -> Set [Callable ]:
894
866
"""
0 commit comments