Commit a6f96a2

Accept model type parameter in export_llama
1 parent: e8715ba

3 files changed: +71 additions, −33 deletions

examples/models/llama2/README.md

Lines changed: 20 additions & 6 deletions
@@ -142,6 +142,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama2.export_llama \
+  --model llama3_2 \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
  --params "${LLAMA_PARAMS:?}" \
  -kv \
@@ -162,6 +163,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama2.export_llama \
+  --model llama3_2 \
  --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
  --params "${LLAMA_PARAMS:?}" \
  --use_sdpa_with_kv_cache \
@@ -185,7 +187,19 @@ You can export and run the original Llama 3 8B instruct model.
 
 2. Export model and generate `.pte` file
    ```
-   python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+   python -m examples.models.llama2.export_llama \
+     --model llama3 \
+     --checkpoint <consolidated.00.pth> \
+     -p <params.json> \
+     -kv \
+     --use_sdpa_with_kv_cache \
+     -X \
+     -qmode 8da4w \
+     --group_size 128 \
+     -d fp32 \
+     --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+     --embedding-quantize 4,32 \
+     --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
    ```
 
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.
@@ -205,7 +219,7 @@ If you want to deploy and run a smaller model for educational purposes. From `ex
    ```
 3. Export model and generate `.pte` file.
    ```
-   python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -X -kv
+   python -m examples.models.llama2.export_llama --model llama2 --checkpoint stories110M.pt --params params.json -X -kv
    ```
 
 ### Option D: Download and export Llama 2 7B model
@@ -218,7 +232,7 @@ You can export and run the original Llama 2 7B model.
 
 3. Export model and generate `.pte` file:
    ```
-   python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32
+   python -m examples.models.llama2.export_llama --model llama2 --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32
    ```
 4. Create tokenizer.bin.
    ```
@@ -432,9 +446,9 @@ Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-de
 Currently we supported lowering the stories model to other backends, including, CoreML, MPS and QNN. Please refer to the instruction
 for each backend ([CoreML](https://pytorch.org/executorch/main/build-run-coreml.html), [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [QNN](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html)) before trying to lower them. After the backend library is installed, the script to export a lowered model is
 
-- Lower to CoreML: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json `
-- MPS: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json `
-- QNN: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json `
+- Lower to CoreML: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json `
+- MPS: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json `
+- QNN: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json `
 
 The iOS LLAMA app supports the CoreML and MPS model and the Android LLAMA app supports the QNN model. On Android, it also allow to cross compiler the llama runner binary, push to the device and run.
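For readers who prefer to drive these exports from a script rather than the shell one-liners above, here is a minimal sketch using the new `--model` flag. The module path is taken from the README commands in this diff; the checkpoint and params paths are placeholders you would substitute.

```python
# Minimal sketch: invoking the export CLI from Python with the new --model flag.
# Checkpoint/params paths are placeholders, not real files from this commit.
import subprocess

cmd = [
    "python", "-m", "examples.models.llama2.export_llama",
    "--model", "llama3_2",
    "--checkpoint", "path/to/checkpoint.pth",
    "--params", "path/to/params.json",
    "-kv", "-X",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError if the export fails
```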

examples/models/llama2/export_llama.py

Lines changed: 1 addition & 2 deletions
@@ -20,10 +20,9 @@
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
-    export_llama(modelname, args)
+    export_llama(args)
 
 
 if __name__ == "__main__":
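Since `export_llama` no longer takes a model-name parameter, callers that previously passed `modelname` now rely entirely on the parsed `--model` argument. A hedged sketch of programmatic use, assuming `build_args_parser` and `export_llama` are importable from `examples.models.llama2.export_llama_lib` as laid out in this diff:

```python
# Sketch only: programmatic use of the updated entry point. The import path and
# flag values are assumptions based on this diff, not a verified public API.
from examples.models.llama2.export_llama_lib import build_args_parser, export_llama

args = build_args_parser().parse_args(
    [
        "--model", "llama3_2",
        "--checkpoint", "path/to/checkpoint.pth",
        "--params", "path/to/params.json",
        "-kv",
    ]
)
pte_path = export_llama(args)  # the model name now travels inside args.model
```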

examples/models/llama2/export_llama_lib.py

Lines changed: 50 additions & 25 deletions
@@ -78,6 +78,10 @@
 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -113,11 +117,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -127,6 +131,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama2",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model to export. llama2, llama3, llama3_1, llama3_2 share the same architecture, so they are technically interchangeable, given you provide the checkpoint file for the desired version.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -458,13 +468,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -475,14 +485,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -508,7 +518,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -530,7 +540,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -574,13 +584,13 @@ def _validate_args(args):
         raise ValueError("Model shard is only supported with qnn backend now.")
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
     builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
         .capture_pre_autograd_graph()
         .pt2e_quantize(quantizers)
         .export_to_edge()
@@ -748,8 +758,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str,
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -776,26 +786,41 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        "llama2",
-        "Llama2Model",
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        # Set to llama2 because all models in EXECUTORCH_DEFINED_MODELS share the same architecture as
+        # defined in examples/models/llama2.
+        modelname = "llama2"
+        model_class_name = "Llama2Model"
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        if modelname == "llama3_2_vision":
+            model_class_name = "Llama3_2Decoder"
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
+    model, example_inputs, example_kwarg_inputs, _ = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(
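The dispatch added above maps every ExecuTorch-defined Llama variant onto the shared `llama2` module and `Llama2Model` class, while torchtune-defined models keep their own module and class before both are handed to `EagerModelFactory.create_model`. The standalone sketch below isolates that mapping; the helper function is hypothetical and only mirrors the logic in this hunk.

```python
# Hypothetical helper mirroring the dispatch in _load_llama_model above;
# not part of the commit, just an isolated illustration of the mapping.
from typing import Tuple

EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def resolve_model(modelname: str) -> Tuple[str, str]:
    """Return the (module name, class name) pair handed to EagerModelFactory."""
    if modelname in EXECUTORCH_DEFINED_MODELS:
        # All ExecuTorch-defined variants reuse the llama2 module and architecture.
        return "llama2", "Llama2Model"
    if modelname in TORCHTUNE_DEFINED_MODELS:
        if modelname == "llama3_2_vision":
            return modelname, "Llama3_2Decoder"
    raise ValueError(f"{modelname} is not a valid Llama model.")


assert resolve_model("llama3_1") == ("llama2", "Llama2Model")
assert resolve_model("llama3_2_vision") == ("llama3_2_vision", "Llama3_2Decoder")
```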
