
Commit c080c48

yifan_shen3 authored and facebook-github-bot committed
Polish UI (#5319)
Summary:
Enable all iOS 18 features by a single `--coreml-ios 18` argument, rather than one arg for each individual feature.

Export our current best model by

```
python -m examples.models.llama2.export_llama \
    -c <download-path>/consolidated.00.pth \
    -p <download-path>/params.json \
    --disable_dynamic_shape -kv \
    --coreml \
    --coreml-ios 18 \
    --coreml-quantize b4w
```

i.e.

* FP16 activation
* Static sequence length
* In-place KV cache
* Fused scaled dot product attention
* 4-bit per-block weight

Pull Request resolved: #5319

Reviewed By: kirklandsign

Differential Revision: D62616303

Pulled By: cccclai

fbshipit-source-id: 416a8acec785662ad76d04ea0b949ce3393308fe
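In effect, the single `--coreml-ios` knob implies a set of available Core ML features. The following standalone sketch restates that mapping; the names `FEATURES_BY_IOS` and `features_available` are hypothetical, and the version facts are taken from the checks in the diffs below.

```python
# Hypothetical sketch: one minimum-iOS knob implies a set of Core ML features.
FEATURES_BY_IOS = {
    16: {"quantization"},
    17: {"8-bit activation quantization"},
    18: {"stateful execution", "fused sdpa", "4-bit per-block weight"},
}

def features_available(ios: int) -> set:
    # A feature is available once the deployment floor reaches the iOS
    # version that introduced it.
    return set().union(*(feats for v, feats in FEATURES_BY_IOS.items() if v <= ios))

assert "fused sdpa" in features_available(18)
assert "fused sdpa" not in features_available(17)
```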
1 parent 9301ebb · commit c080c48

File tree

2 files changed: +50, -29 lines


examples/models/llama2/export_llama_lib.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -317,6 +317,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         choices=["b4w"],
         help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)",
     )
+    parser.add_argument(
+        "--coreml-ios",
+        type=int,
+        default=15,
+        choices=(15, 16, 17, 18),
+        help="This option is only for coreml: The minimum iOS version to deploy",
+    )
     parser.add_argument(
         "--qnn",
         action="store_true",
@@ -533,8 +540,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901

     if args.coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.use_kv_cache and args.coreml_enable_state,
-            args.coreml_preserve_sdpa,
+            args.coreml_ios,
             args.embedding_quantize,
             args.pt2e_quantize,
             args.coreml_quantize,
@@ -810,7 +816,8 @@ def _get_source_transforms(  # noqa
             transforms.append(replace_causal_mask)

         elif args.coreml:
-            if args.coreml_preserve_sdpa:
+            # iOS 18 introduced fused sdpa op
+            if args.coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
```
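The SDPA gating in the last hunk reduces to a small pure function. Below is a self-contained sketch of that decision; `pick_sdpa_transform` is a hypothetical helper that returns transform names as strings rather than the actual transform callables, so it runs standalone.

```python
# Minimal sketch of the version gate in the diff above (hypothetical helper).
def pick_sdpa_transform(coreml_ios: int) -> str:
    # Core ML introduced a fused scaled-dot-product-attention op in iOS 18;
    # older deployment targets fall back to a simple decomposition.
    if coreml_ios >= 18:
        return "replace_sdpa_with_coreml_sdpa"
    return "replace_sdpa_with_simple_sdpa"

assert pick_sdpa_transform(18) == "replace_sdpa_with_coreml_sdpa"
assert pick_sdpa_transform(15) == "replace_sdpa_with_simple_sdpa"
```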

extension/llm/export/partitioner_lib.py

Lines changed: 40 additions & 26 deletions
```diff
@@ -56,8 +56,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):


 def get_coreml_partitioner(
-    enable_state: bool = False,
-    preserve_sdpa: bool = True,
+    ios: int = 15,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
     coreml_quantize: Optional[str] = None,
@@ -75,29 +74,42 @@ def get_coreml_partitioner(
             "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
         )

-    minimum_deployment_target = ct.target.iOS15
-    # In Core ML, stateful execution is introduced in iOS 18
-    if enable_state:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, sdpa op is introduced in iOS 18
-    if preserve_sdpa:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, quantization is introduced in iOS 16
-    if embedding_quantize is not None or pt2e_quantize is not None:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
-    # In Core ML, 8-bit activation quantization is introduced in iOS 17
-    if (
-        embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
-    ) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
-    # In Core ML, 4-bit weight compression is introduced in iOS 18
-    if (
-        (embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4)
-        or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w")
-        or coreml_quantize == "b4w"
-    ):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    def _validate_ios_version() -> None:
+        assert ios in (15, 16, 17, 18)

+        if embedding_quantize is not None and ios < 18:
+            raise ValueError(
+                "In Core ML, per-block quantization is introduced in iOS 18"
+            )
+
+        use_quantization = pt2e_quantize is not None or coreml_quantize is not None
+        if use_quantization and ios < 16:
+            raise ValueError("In Core ML, quantization is introduced in iOS 16")
+
+        use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
+            coreml_quantize is not None and "8a" in coreml_quantize
+        )
+        if use_8a and ios < 17:
+            raise ValueError(
+                "In Core ML, 8-bit activation quantization is introduced in iOS 17"
+            )
+
+        use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
+            coreml_quantize is not None and "4w" in coreml_quantize
+        )
+        if use_4w and ios < 18:
+            raise ValueError(
+                "In Core ML, 4-bit weight compression is introduced in iOS 18"
+            )
+
+    _validate_ios_version()
+
+    minimum_deployment_target = {
+        15: ct.target.iOS15,
+        16: ct.target.iOS16,
+        17: ct.target.iOS17,
+        18: ct.target.iOS18,
+    }[ios]
     op_linear_quantizer_config = None
     if coreml_quantize == "b4w":
         op_linear_quantizer_config = {
@@ -107,7 +119,6 @@ def get_coreml_partitioner(
             "block_size": 32,
             "weight_threshold": 512,
         }
-
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
@@ -116,9 +127,12 @@ def get_coreml_partitioner(
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
     )
+
+    take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18
+
     return CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
-        take_over_mutable_buffer=enable_state,
+        take_over_mutable_buffer=take_over_mutable_buffer,
     )
```
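A hedged usage sketch of the new signature, assuming coremltools and the CoreML backend are installed, and that the module is importable as `executorch.extension.llm.export.partitioner_lib` (inferred from the file path above, not confirmed by this page):

```python
# Usage sketch for the new signature; import path is an assumption based on
# the file path in this diff, and requires the CoreML backend to be installed.
from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner

# iOS 18 floor with blockwise 4-bit weights: passes the new validation.
partitioner = get_coreml_partitioner(ios=18, coreml_quantize="b4w")

# An inconsistent combination now fails fast instead of silently raising
# the deployment target behind the caller's back:
try:
    get_coreml_partitioner(ios=15, coreml_quantize="b4w")
except ValueError as err:
    print(err)  # In Core ML, quantization is introduced in iOS 16
```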