@@ -56,8 +56,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):
 
 
 def get_coreml_partitioner(
-    enable_state: bool = False,
-    preserve_sdpa: bool = True,
+    ios: int = 15,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
     coreml_quantize: Optional[str] = None,
@@ -75,29 +74,42 @@ def get_coreml_partitioner(
             "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
         )
 
-    minimum_deployment_target = ct.target.iOS15
-    # In Core ML, stateful execution is introduced in iOS 18
-    if enable_state:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, sdpa op is introduced in iOS 18
-    if preserve_sdpa:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, quantization is introduced in iOS 16
-    if embedding_quantize is not None or pt2e_quantize is not None:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
-    # In Core ML, 8-bit activation quantization is introduced in iOS 17
-    if (
-        embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
-    ) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
-    # In Core ML, 4-bit weight compression is introduced in iOS 18
-    if (
-        (embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4)
-        or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w")
-        or coreml_quantize == "b4w"
-    ):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    def _validate_ios_version() -> None:
+        assert ios in (15, 16, 17, 18)
 
+        if embedding_quantize is not None and ios < 18:
+            raise ValueError(
+                "In Core ML, per-block quantization is introduced in iOS 18"
+            )
+
+        use_quantization = pt2e_quantize is not None or coreml_quantize is not None
+        if use_quantization and ios < 16:
+            raise ValueError("In Core ML, quantization is introduced in iOS 16")
+
+        use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
+            coreml_quantize is not None and "8a" in coreml_quantize
+        )
+        if use_8a and ios < 17:
+            raise ValueError(
+                "In Core ML, 8-bit activation quantization is introduced in iOS 17"
+            )
+
+        use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
+            coreml_quantize is not None and "4w" in coreml_quantize
+        )
+        if use_4w and ios < 18:
+            raise ValueError(
+                "In Core ML, 4-bit weight compression is introduced in iOS 18"
+            )
+
+    _validate_ios_version()
+
+    minimum_deployment_target = {
+        15: ct.target.iOS15,
+        16: ct.target.iOS16,
+        17: ct.target.iOS17,
+        18: ct.target.iOS18,
+    }[ios]
     op_linear_quantizer_config = None
     if coreml_quantize == "b4w":
         op_linear_quantizer_config = {
@@ -107,7 +119,6 @@ def get_coreml_partitioner(
             "block_size": 32,
             "weight_threshold": 512,
         }
-
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
@@ -116,9 +127,12 @@ def get_coreml_partitioner(
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
     )
+
+    take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18
+
     return CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
-        take_over_mutable_buffer=enable_state,
+        take_over_mutable_buffer=take_over_mutable_buffer,
     )
 
 
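For reference, a minimal usage sketch of the reworked API: callers now pick an explicit iOS deployment target instead of the removed `enable_state` / `preserve_sdpa` flags. The import path and the argument values below are assumptions for illustration; only `get_coreml_partitioner` and its parameters come from the diff, and running this requires coremltools and the ExecuTorch CoreML backend to be installed.

```python
# Hypothetical usage sketch, not part of this commit. The import path is
# assumed from the file being patched (extension/llm/export/partitioner_lib.py).
from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner

# iOS 18 target: minimum_deployment_target becomes ct.target.iOS18, so the
# partitioner also takes over mutable buffers (stateful execution).
partitioner = get_coreml_partitioner(ios=18, coreml_quantize="b4w")

# Default iOS 15 target with no quantization: validation passes and
# take_over_mutable_buffer stays False.
baseline = get_coreml_partitioner(ios=15)
```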
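Note the behavioral change: the new `_validate_ios_version` helper fails fast with a `ValueError` instead of silently bumping the deployment target the way the old `max(...)` chain did. A sketch of the expected failures, derived from the checks in the diff (same hypothetical import and install prerequisites as above):

```python
from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner

# Each call below violates one of the version gates in _validate_ios_version.
try:
    # "8a" in pt2e_quantize means 8-bit activations, which need iOS 17+.
    get_coreml_partitioner(ios=16, pt2e_quantize="coreml_8a_c8w")
except ValueError as e:
    print(e)  # In Core ML, 8-bit activation quantization is introduced in iOS 17

try:
    # "b4w" means 4-bit weight compression, which needs iOS 18+.
    get_coreml_partitioner(ios=17, coreml_quantize="b4w")
except ValueError as e:
    print(e)  # In Core ML, 4-bit weight compression is introduced in iOS 18
```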