Commit a044fb1

Accept model type parameter in export_llama
1 parent b48f917 commit a044fb1

3 files changed: 71 additions & 35 deletions


examples/models/llama2/README.md

Lines changed: 20 additions & 8 deletions
@@ -140,6 +140,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama2.export_llama \
+  --model llama3_2 \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -kv -X \
@@ -158,6 +159,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama2.export_llama \
+  --model llama3_2 \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
@@ -181,7 +183,19 @@ You can export and run the original Llama 3 8B instruct model.
 
 2. Export model and generate `.pte` file
 ```
-python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+python -m examples.models.llama2.export_llama \
+  --model llama3 \
+  --checkpoint <consolidated.00.pth> \
+  -p <params.json> \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -X \
+  -qmode 8da4w \
+  --group_size 128 \
+  -d fp32 \
+  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+  --embedding-quantize 4,32 \
+  --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
 ```
 
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.
@@ -201,10 +215,9 @@ If you want to deploy and run a smaller model for educational purposes. From `ex
 ```
 3. Export model and generate `.pte` file.
 ```
-python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -X -kv
+python -m examples.models.llama2.export_llama --model llama2 --checkpoint stories110M.pt --params params.json -X -kv
 ```
 4. Create tokenizer.bin.
-
 ```
 python -m extension.llm.tokenizer.tokenizer -t <tokenizer.model> -o tokenizer.bin
 ```
@@ -219,10 +232,9 @@ You can export and run the original Llama 2 7B model.
 
 3. Export model and generate `.pte` file:
 ```
-python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32
+python -m examples.models.llama2.export_llama --model llama2 --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32
 ```
 4. Create tokenizer.bin.
-
 ```
 python -m extension.llm.tokenizer.tokenizer -t <tokenizer.model> -o tokenizer.bin
 ```
@@ -414,9 +426,9 @@ Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-de
 Currently we support lowering the stories model to other backends, including CoreML, MPS and QNN. Please refer to the instructions
 for each backend ([CoreML](https://pytorch.org/executorch/main/build-run-coreml.html), [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [QNN](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html)) before trying to lower them. After the backend library is installed, the script to export a lowered model is:
 
-- Lower to CoreML: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json`
-- MPS: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json`
-- QNN: `python -m examples.models.llama2.export_llama -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json`
+- Lower to CoreML: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json`
+- MPS: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json`
+- QNN: `python -m examples.models.llama2.export_llama --model llama3 -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json`
 
 The iOS LLAMA app supports the CoreML and MPS models, and the Android LLAMA app supports the QNN model. On Android, you can also cross compile the llama runner binary, push it to the device, and run it.
 
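
Note: the README commands above can also be driven from Python. A minimal sketch, assuming the repository root is on the import path (the README itself runs the module with `python -m examples.models.llama2.export_llama`); the checkpoint and params paths are placeholders:

```python
# Hedged sketch: programmatic equivalent of the Llama 3.2 export command above.
# Flag names come from the README; the paths are placeholders, not real files.
from examples.models.llama2.export_llama_lib import build_args_parser, export_llama

parser = build_args_parser()
args = parser.parse_args(
    [
        "--model", "llama3_2",                     # flag added by this commit
        "--checkpoint", "path/to/checkpoint.pth",  # placeholder
        "--params", "path/to/params.json",         # placeholder
        "-kv",                                     # enable the KV cache
        "-X",                                      # lower to XNNPACK
    ]
)
print(export_llama(args))  # export_llama(args) returns the saved .pte filename
```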

examples/models/llama2/export_llama.py

Lines changed: 1 addition & 2 deletions
@@ -20,10 +20,9 @@
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
-    export_llama(modelname, args)
+    export_llama(args)
 
 
 if __name__ == "__main__":

examples/models/llama2/export_llama_lib.py

Lines changed: 50 additions & 25 deletions
@@ -77,6 +77,10 @@
 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -112,11 +116,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
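
For reference, the only behavioral change in `build_model` is where `--model` lands in the generated argument string. A tiny reproduction of that f-string with illustrative values (the `par:` resource-path convention is taken verbatim from the hunk above; the real function fills these values in from its own parameters):

```python
# Illustrative values only; build_model() supplies these from its own arguments.
modelname = "llama3_2"
extra_opts = "-kv -X"
output_dir_path = "."

argString = f"--model {modelname} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
print(argString)
# --model llama3_2 --checkpoint par:llama3_2_ckpt.pt --params par:llama3_2_params.json -kv -X --output-dir .
```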
@@ -126,6 +130,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama2",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model to export. llama2, llama3, llama3_1, and llama3_2 share the same architecture, so they are technically interchangeable, given you provide the checkpoint file for the desired version.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -456,13 +466,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -473,14 +483,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -506,7 +516,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -528,7 +538,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -566,13 +576,13 @@ def _validate_args(args):
         raise ValueError("Model shard is only supported with qnn backend now.")
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
     builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
         .capture_pre_autograd_graph()
         .pt2e_quantize(quantizers)
         .export_to_edge()
@@ -746,8 +756,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str,
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
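
The hunk above moves `modelname` in front of the keyword-only marker, making it a required positional parameter; that is why `_prepare_for_llama_export` now passes `args.model` positionally. A minimal sketch of just that calling convention (`_load_llama_model_stub` is a hypothetical stand-in, not the real function):

```python
def _load_llama_model_stub(modelname: str, *, checkpoint=None, params_path: str):
    # modelname is positional-or-keyword and required; everything after the bare *
    # must still be passed by keyword, as before this commit.
    return modelname, checkpoint, params_path


_load_llama_model_stub("llama3_2", checkpoint="ckpt.pth", params_path="params.json")
```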
@@ -774,26 +784,41 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
-    model, example_inputs, _ = EagerModelFactory.create_model(
-        "llama2",
-        "Llama2Model",
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        # Set to llama2 because all models in EXECUTORCH_DEFINED_MODELS share the same architecture as
+        # defined in examples/models/llama2.
+        modelname = "llama2"
+        model_class_name = "Llama2Model"
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        if modelname == "llama3_2_vision":
+            model_class_name = "Llama3_2Decoder"
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
+    model, (example_inputs, example_kwarg_inputs), dynamic_shapes = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(
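
The model-selection logic added to `_load_llama_model` can also be read in isolation. A self-contained sketch of the same dispatch (the lists and class names are copied from this diff; `resolve_model_entry` is a hypothetical helper, and `EagerModelFactory` is not reproduced here):

```python
EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def resolve_model_entry(modelname: str) -> tuple:
    """Map a --model value to the (module name, class name) pair given to EagerModelFactory."""
    if modelname in EXECUTORCH_DEFINED_MODELS:
        # All ExecuTorch-defined Llama variants share the architecture defined in
        # examples/models/llama2, so they all load through Llama2Model.
        return "llama2", "Llama2Model"
    if modelname in TORCHTUNE_DEFINED_MODELS and modelname == "llama3_2_vision":
        return modelname, "Llama3_2Decoder"
    raise ValueError(f"{modelname} is not a valid Llama model.")


print(resolve_model_entry("llama3_1"))         # ('llama2', 'Llama2Model')
print(resolve_model_entry("llama3_2_vision"))  # ('llama3_2_vision', 'Llama3_2Decoder')
```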
