
Commit 331cc3e

Accept model type parameter in export_llama
1 parent d4b9d39 commit 331cc3e

3 files changed: 71 additions, 33 deletions
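
In short, the export entry point now takes the target model as an explicit `--model` argument (defaulting to `llama2`) instead of hard-coding the model name, and the exporter reads it from the parsed arguments. A minimal sketch of the flag's behavior, assuming the repository root is on `PYTHONPATH` and using `build_args_parser()` from `examples/models/llama2/export_llama_lib.py`; the checkpoint and params paths are placeholders:

```python
# Sketch only: exercises the --model flag added in this commit.
from examples.models.llama2.export_llama_lib import build_args_parser

parser = build_args_parser()
# --model defaults to "llama2", so existing invocations keep working;
# values outside the supported model lists are rejected by argparse's choices.
args = parser.parse_args(
    ["--model", "llama3_2", "--checkpoint", "ckpt.pth", "--params", "params.json"]
)
print(args.model)  # llama3_2
```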

examples/models/llama2/README.md

Lines changed: 20 additions & 6 deletions
@@ -140,6 +140,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama2.export_llama \
+--model llama3_2 \
 --checkpoint "${LLAMA_CHECKPOINT:?}" \
 --params "${LLAMA_PARAMS:?}" \
 -kv \
@@ -160,6 +161,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama2.export_llama \
+--model llama3_2 \
 --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
 --params "${LLAMA_PARAMS:?}" \
 --use_sdpa_with_kv_cache \
@@ -183,7 +185,19 @@ You can export and run the original Llama 3 8B instruct model.
 
 2. Export model and generate `.pte` file
 ```
-python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+python -m examples.models.llama2.export_llama \
+--model llama3 \
+--checkpoint <consolidated.00.pth> \
+-p <params.json> \
+-kv \
+--use_sdpa_with_kv_cache \
+-X \
+-qmode 8da4w \
+--group_size 128 \
+-d fp32 \
+--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+--embedding-quantize 4,32 \
+--output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
 ```
 
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.
@@ -203,7 +217,7 @@ If you want to deploy and run a smaller model for educational purposes. From `ex
 ```
 3. Export model and generate `.pte` file.
 ```
-python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -X -kv
+python -m examples.models.llama2.export_llama --model llama2 --checkpoint stories110M.pt --params params.json -X -kv
 ```
 
 ### Option D: Download and export Llama 2 7B model
@@ -216,7 +230,7 @@ You can export and run the original Llama 2 7B model.
 
 3. Export model and generate `.pte` file:
 ```
-python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32
+python -m examples.models.llama2.export_llama --model llama2 --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32
 ```
 4. Create tokenizer.bin.
 ```
@@ -410,9 +424,9 @@ Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-de
 Currently we supported lowering the stories model to other backends, including, CoreML, MPS and QNN. Please refer to the instruction
 for each backend ([CoreML](https://pytorch.org/executorch/main/build-run-coreml.html), [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [QNN](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html)) before trying to lower them. After the backend library is installed, the script to export a lowered model is
 
-- Lower to CoreML: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json `
-- MPS: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json `
-- QNN: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json `
+- Lower to CoreML: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json `
+- MPS: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json `
+- QNN: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json `
 
 The iOS LLAMA app supports the CoreML and MPS model and the Android LLAMA app supports the QNN model. On Android, it also allow to cross compiler the llama runner binary, push to the device and run.
 

examples/models/llama2/export_llama.py

Lines changed: 1 addition & 2 deletions
@@ -20,10 +20,9 @@
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
-    export_llama(modelname, args)
+    export_llama(args)
 
 
 if __name__ == "__main__":
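
Since `export_llama`, `_prepare_for_llama_export`, and `_export_llama` (changed in the next file) no longer accept a `modelname` parameter, callers that previously passed one now carry the model name on the parsed arguments instead. A hedged migration sketch, again with placeholder checkpoint and params paths:

```python
# Sketch only: the modelname argument is gone; the model name rides on args.model.
from examples.models.llama2.export_llama_lib import build_args_parser, export_llama

# Before this commit: export_llama("llama3", args)
args = build_args_parser().parse_args(
    ["--model", "llama3", "--checkpoint", "ckpt.pth", "--params", "params.json"]
)
pte_filename = export_llama(args)  # returns the saved .pte file name
```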

examples/models/llama2/export_llama_lib.py

Lines changed: 50 additions & 25 deletions
@@ -77,6 +77,10 @@
 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -112,11 +116,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -126,6 +130,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama2",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model to export. llama2, llama3, llama3_1, and llama3_2 share the same architecture, so they are technically interchangeable, given that you provide the checkpoint file for the desired version.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -456,13 +466,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -473,14 +483,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
            filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -506,7 +516,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -528,7 +538,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -566,13 +576,13 @@ def _validate_args(args):
         raise ValueError("Model shard is only supported with qnn backend now.")
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
     builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
         .capture_pre_autograd_graph()
         .pt2e_quantize(quantizers)
         .export_to_edge()
@@ -746,8 +756,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str,
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -774,26 +784,41 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        "llama2",
-        "Llama2Model",
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        # Set to llama2 because all models in EXECUTORCH_DEFINED_MODELS share the same architecture as
+        # defined in examples/models/llama2.
+        modelname = "llama2"
+        model_class_name = "Llama2Model"
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        if modelname == "llama3_2_vision":
+            model_class_name = "Llama3_2Decoder"
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
+    model, example_inputs, example_kwarg_inputs, _ = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(
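
The routing added in `_load_llama_model` boils down to mapping the user-facing `--model` value onto the module name and class name handed to `EagerModelFactory.create_model`. A standalone sketch of that mapping; the helper name `resolve_model_target` is illustrative and not part of the commit:

```python
from typing import Tuple

EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def resolve_model_target(modelname: str) -> Tuple[str, str]:
    """Return the (module_name, class_name) pair used to instantiate the eager model."""
    if modelname in EXECUTORCH_DEFINED_MODELS:
        # All text-only Llama variants reuse the architecture defined in examples/models/llama2.
        return "llama2", "Llama2Model"
    if modelname in TORCHTUNE_DEFINED_MODELS:
        # Currently only the llama3_2_vision decoder is defined through torchtune.
        return modelname, "Llama3_2Decoder"
    raise ValueError(f"{modelname} is not a valid Llama model.")


assert resolve_model_target("llama3_1") == ("llama2", "Llama2Model")
assert resolve_model_target("llama3_2_vision") == ("llama3_2_vision", "Llama3_2Decoder")
```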
