Commit 4ef6875

Author: yifan_shen3

Add Core ML quantizer options then use them in Llama export; use appropriate iOS version accordingly

Parent: 5a20a49

File tree: 3 files changed, +74 -2 lines


examples/models/llama2/export_llama_lib.py

Lines changed: 11 additions & 1 deletion
@@ -35,6 +35,7 @@
 )
 
 from executorch.extension.llm.export.quantizer_lib import (
+    get_coreml_quantizer,
     get_pt2e_quantization_params,
     get_pt2e_quantizers,
     get_qnn_quantizer,
@@ -128,6 +129,10 @@ def build_args_parser() -> argparse.ArgumentParser:
             "qnn_8a8w",
             "qnn_16a16w",
             "qnn_16a4w",
+            "coreml",
+            "coreml_qc4",
+            "coreml_baseline",
+            "coreml_baseline_qc4",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -416,6 +421,10 @@ def get_quantizer_and_quant_params(args):
             args.pt2e_quantize, args.quantization_mode
         )
         quantizers.append(qnn_quantizer)
+    if args.coreml and args.pt2e_quantize:
+        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
+        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        quantizers.append(coreml_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype

@@ -469,7 +478,8 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"mps_{modelname}"
 
     if args.coreml:
-        partitioners.append(get_coreml_partitioner(args.use_kv_cache))
+        coreml_partitioner = get_coreml_partitioner(args.use_kv_cache, args.pt2e_quantize)
+        partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
     if args.qnn:
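Taken together, the wiring above can be illustrated with a minimal sketch. The args object below is a hand-built stand-in for the namespace produced by build_args_parser, populated only with the fields the Core ML path reads; it is illustrative, not the export script itself:

from types import SimpleNamespace

from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner
from executorch.extension.llm.export.quantizer_lib import get_coreml_quantizer

# Hand-built stand-in for the parsed CLI arguments (hypothetical values).
args = SimpleNamespace(coreml=True, pt2e_quantize="coreml", use_kv_cache=True)

quantizers = []
if args.coreml and args.pt2e_quantize:
    # Mirrors the new branch: Core ML quantization is mutually exclusive
    # with the xnnpack / qnn quantizers.
    assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
    quantizers.append(get_coreml_quantizer(args.pt2e_quantize))

partitioners = []
if args.coreml:
    # The partitioner now also receives the quantization choice, so it can
    # raise the minimum iOS deployment target when needed.
    partitioners.append(get_coreml_partitioner(args.use_kv_cache, args.pt2e_quantize))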

extension/llm/export/partitioner_lib.py

Lines changed: 17 additions & 1 deletion
@@ -55,7 +55,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):
     return MPSPartitioner(compile_specs)
 
 
-def get_coreml_partitioner(use_kv_cache: bool = False):
+def get_coreml_partitioner(use_kv_cache: bool = False, pt2e_quantize: str = None):
     assert (
         use_kv_cache is True
     ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
@@ -72,7 +72,23 @@ def get_coreml_partitioner(use_kv_cache: bool = False):
             "Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
         )
 
+    minimum_deployment_target = ct.target.iOS15
+    # In Core ML, activation quantization is introduced in iOS 17
+    if pt2e_quantize is not None:
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
+    # In Core ML, 4-bit weight compression is introduced in iOS 18
+    if pt2e_quantize in ("coreml_qc4", "coreml_baseline_qc4"):
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    # In Core ML, stateful execution is introduced in iOS 18
+    # TODO (https://github.com/pytorch/executorch/issues/4209)
+    # For now, since mutable buffer is kept in executorch runtime,
+    # state is out of place and can be handled by older iOS.
+    # Once mutable buffer can be handed over to delegate, i.e. state becomes in-place, we will have
+    # if use_kv_cache:
+    #     minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+
     compile_specs = CoreMLBackend.generate_compile_specs(
+        minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
         # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
         compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
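As a sanity check on the new logic, the deployment-target selection can be restated as a standalone snippet. This is illustrative only: it assumes a coremltools version that defines ct.target.iOS18, and it relies on ct.target members being comparable integers, which is how the code above uses max:

import coremltools as ct

def resolve_minimum_deployment_target(pt2e_quantize=None):
    # Re-statement of the selection logic in get_coreml_partitioner.
    target = ct.target.iOS15
    if pt2e_quantize is not None:
        # Activation quantization requires iOS 17.
        target = max(target, ct.target.iOS17)
    if pt2e_quantize in ("coreml_qc4", "coreml_baseline_qc4"):
        # 4-bit weight compression requires iOS 18.
        target = max(target, ct.target.iOS18)
    return target

for option in (None, "coreml", "coreml_baseline", "coreml_qc4", "coreml_baseline_qc4"):
    print(option, resolve_minimum_deployment_target(option))
# Expected: None -> iOS15; coreml / coreml_baseline -> iOS17; *_qc4 options -> iOS18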

extension/llm/export/quantizer_lib.py

Lines changed: 46 additions & 0 deletions
@@ -193,3 +193,49 @@ def get_qnn_quantizer(
     ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
     qnn_quantizer.add_custom_quant_annotations(custom_annotations)
     return qnn_quantizer, quant_dtype
+
+
+def get_coreml_quantizer(pt2e_quantize: str):
+    try:
+        from coremltools.optimize.torch.quantization.quantization_config import (
+            LinearQuantizerConfig,
+            QuantizationScheme,
+        )
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+    except ImportError:
+        raise ImportError(
+            "Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
+        )
+
+    if pt2e_quantize == "coreml":
+        config = LinearQuantizerConfig.from_dict(
+            {
+                "global_config": {
+                    "quantization_scheme": QuantizationScheme.affine,
+                    "activation_dtype": torch.quint8,
+                    "weight_dtype": torch.qint8,
+                    "weight_per_channel": True,
+                }
+            }
+        )
+        quantizer = CoreMLQuantizer(config)
+
+    elif pt2e_quantize == "coreml_qc4":
+        raise NotImplementedError("4-bit coreml quantizer is still under development")
+
+    elif pt2e_quantize == "coreml_baseline":
+        config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=False
+        )
+        quantizer = XNNPACKQuantizer().set_global(config)
+
+    elif pt2e_quantize == "coreml_baseline_qc4":
+        config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
+        )
+        quantizer = XNNPACKQuantizer().set_global(config)
+
+    else:
+        raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")
+
+    return quantizer
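For context, a quantizer returned by get_coreml_quantizer is meant to be consumed by the standard PT2E flow. A rough sketch follows; the toy model and calibration input are placeholders, and the capture call (torch.export.export_for_training, available in recent PyTorch releases) is version-dependent, so treat this as a sketch of the flow rather than the exact path the export script takes:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.extension.llm.export.quantizer_lib import get_coreml_quantizer

# Placeholder model and example input, only to illustrate the flow.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 16),)

quantizer = get_coreml_quantizer("coreml")

# Capture -> annotate -> calibrate -> convert (PT2E flow).
captured = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)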
