
Commit 68d345c

Author: yifan_shen3
Add Core ML quantizer options, then use them in Llama export; use the appropriate iOS version accordingly

1 parent 5a20a49

File tree: 3 files changed, +85 -2 lines

examples/models/llama2/export_llama_lib.py
Lines changed: 14 additions & 1 deletion

@@ -35,6 +35,7 @@
 )
 
 from executorch.extension.llm.export.quantizer_lib import (
+    get_coreml_quantizer,
     get_pt2e_quantization_params,
     get_pt2e_quantizers,
     get_qnn_quantizer,
@@ -128,6 +129,11 @@ def build_args_parser() -> argparse.ArgumentParser:
             "qnn_8a8w",
             "qnn_16a16w",
             "qnn_16a4w",
+            "coreml_c4w",
+            "coreml_8a_c8w",
+            "coreml_8a_c4w",
+            "coreml_baseline_8a_c8w",
+            "coreml_baseline_8a_c4w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -416,6 +422,10 @@ def get_quantizer_and_quant_params(args):
             args.pt2e_quantize, args.quantization_mode
         )
         quantizers.append(qnn_quantizer)
+    if args.coreml and args.pt2e_quantize:
+        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
+        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        quantizers.append(coreml_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -469,7 +479,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"mps_{modelname}"
 
     if args.coreml:
-        partitioners.append(get_coreml_partitioner(args.use_kv_cache))
+        coreml_partitioner = get_coreml_partitioner(
+            args.use_kv_cache, args.pt2e_quantize
+        )
+        partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
     if args.qnn:
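
With this wiring, a Core ML quantizer is picked up only when --coreml is combined with one of the new coreml_* values of --pt2e_quantize, and it must not be mixed with the xnnpack or qnn quantizers. Below is a minimal sketch of that selection logic in isolation, assuming executorch and coremltools are installed; SimpleNamespace stands in for the argparse namespace that build_args_parser produces:

    # Sketch: mirrors the quantizer selection added in get_quantizer_and_quant_params.
    from types import SimpleNamespace

    from executorch.extension.llm.export.quantizer_lib import get_coreml_quantizer

    # Stand-in for the parsed CLI arguments; only the fields used here are set.
    args = SimpleNamespace(coreml=True, pt2e_quantize="coreml_8a_c8w")

    quantizers = []
    if args.coreml and args.pt2e_quantize:
        # Core ML must be the only quantizer in the list.
        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
        quantizers.append(get_coreml_quantizer(args.pt2e_quantize))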

extension/llm/export/partitioner_lib.py
Lines changed: 22 additions & 1 deletion

@@ -55,7 +55,9 @@ def get_mps_partitioner(use_kv_cache: bool = False):
     return MPSPartitioner(compile_specs)
 
 
-def get_coreml_partitioner(use_kv_cache: bool = False):
+def get_coreml_partitioner(
+    use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None
+):
     assert (
         use_kv_cache is True
     ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
@@ -72,7 +74,26 @@ def get_coreml_partitioner(use_kv_cache: bool = False):
            "Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
        )
 
+    minimum_deployment_target = ct.target.iOS15
+    # In Core ML, quantization is introduced in iOS 16
+    if pt2e_quantize is not None:
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
+    # In Core ML, 8-bit activation quantization is introduced in iOS 17
+    if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
+    # In Core ML, 4-bit weight compression is introduced in iOS 18
+    if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    # In Core ML, stateful execution is introduced in iOS 18
+    # TODO (https://github.com/pytorch/executorch/issues/4209)
+    # For now, since the mutable buffer is kept in the executorch runtime,
+    # state is out of place and can be handled by older iOS.
+    # Once the mutable buffer can be handed over to the delegate, i.e. state becomes in-place, we will have:
+    # if use_kv_cache:
+    #     minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+
     compile_specs = CoreMLBackend.generate_compile_specs(
+        minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
         # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
         compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
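
The deployment-target logic above is monotone: each requested feature can only raise the minimum iOS version, never lower it, so the max() chain is order-independent. Here is a standalone sketch of the same escalation, with plain integers standing in for coremltools' ct.target enum values (iOS15 < iOS16 < iOS17 < iOS18):

    from typing import Optional

    IOS15, IOS16, IOS17, IOS18 = 15, 16, 17, 18

    def minimum_ios_for(pt2e_quantize: Optional[str]) -> int:
        target = IOS15
        if pt2e_quantize is not None:  # any Core ML quantization needs iOS 16
            target = max(target, IOS16)
        if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
            target = max(target, IOS17)  # 8-bit activation quantization
        if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
            target = max(target, IOS18)  # 4-bit weight compression
        return target

    assert minimum_ios_for(None) == IOS15
    assert minimum_ios_for("coreml_8a_c8w") == IOS17
    assert minimum_ios_for("coreml_8a_c4w") == IOS18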

extension/llm/export/quantizer_lib.py
Lines changed: 49 additions & 0 deletions

@@ -193,3 +193,52 @@ def get_qnn_quantizer(
     ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
     qnn_quantizer.add_custom_quant_annotations(custom_annotations)
     return qnn_quantizer, quant_dtype
+
+
+def get_coreml_quantizer(pt2e_quantize: str):
+    try:
+        from coremltools.optimize.torch.quantization.quantization_config import (
+            LinearQuantizerConfig,
+            QuantizationScheme,
+        )
+
+        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.quantizer`.
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+    except ImportError:
+        raise ImportError(
+            "Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
+        )
+
+    if pt2e_quantize == "coreml_8a_c8w":
+        config = LinearQuantizerConfig.from_dict(
+            {
+                "global_config": {
+                    "quantization_scheme": QuantizationScheme.affine,
+                    "activation_dtype": torch.quint8,
+                    "weight_dtype": torch.qint8,
+                    "weight_per_channel": True,
+                }
+            }
+        )
+        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`.
+        quantizer = CoreMLQuantizer(config)
+
+    elif pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w"):
+        raise NotImplementedError("4-bit Core ML quantizer is still under development")
+
+    elif pt2e_quantize == "coreml_baseline_8a_c8w":
+        config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=False
+        )
+        quantizer = XNNPACKQuantizer().set_global(config)
+
+    elif pt2e_quantize == "coreml_baseline_8a_c4w":
+        config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
+        )
+        quantizer = XNNPACKQuantizer().set_global(config)
+
+    else:
+        raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")
+
+    return quantizer
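
For context, a quantizer returned by get_coreml_quantizer is meant to be driven through the standard PT2E flow (prepare_pt2e, calibrate, convert_pt2e). The sketch below uses the XNNPACKQuantizer-backed coreml_baseline_8a_c8w path and a toy model; the capture step shown is torch.export.export, though the exact capture API differs across PyTorch releases, so treat that detail as an assumption:

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    from executorch.extension.llm.export.quantizer_lib import get_coreml_quantizer

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear(x)

    model = TinyModel().eval()
    example_inputs = (torch.randn(1, 8),)

    quantizer = get_coreml_quantizer("coreml_baseline_8a_c8w")

    # Capture, annotate with the quantizer, calibrate, then convert.
    captured = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(captured, quantizer)
    prepared(*example_inputs)  # one calibration pass with representative data
    quantized = convert_pt2e(prepared)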
