Skip to content

Commit c20ef9f

Browse files
author
yifan_shen3
committed
Support embedding quantization
1 parent 6a7ccca commit c20ef9f

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,9 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
511511

512512
if args.coreml:
513513
coreml_partitioner = get_coreml_partitioner(
514-
args.use_kv_cache and args.coreml_enable_state, args.pt2e_quantize
514+
args.use_kv_cache and args.coreml_enable_state,
515+
args.embedding_quantize,
516+
args.pt2e_quantize,
515517
)
516518
partitioners.append(coreml_partitioner)
517519
modelname = f"coreml_{modelname}"

extension/llm/export/partitioner_lib.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def get_mps_partitioner(use_kv_cache: bool = False):
5656

5757

5858
def get_coreml_partitioner(
59-
enable_state: bool = False, pt2e_quantize: Optional[str] = None
59+
enable_state: bool = False,
60+
embedding_quantize: Optional[str] = None,
61+
pt2e_quantize: Optional[str] = None,
6062
):
6163
try:
6264
import coremltools as ct
@@ -76,13 +78,17 @@ def get_coreml_partitioner(
7678
if enable_state:
7779
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
7880
# In Core ML, quantization is introduced in iOS 16
79-
if pt2e_quantize is not None:
81+
if embedding_quantize is not None or pt2e_quantize is not None:
8082
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
8183
# In Core ML, 8-bit activation quantization is introduced in iOS 17
82-
if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
84+
if (
85+
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
86+
) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
8387
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
8488
# In Core ML, 4-bit weight compression is introduced in iOS 18
85-
if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
89+
if (
90+
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4
91+
) or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
8692
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
8793

8894
compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]

0 commit comments

Comments (0)