
Commit dbd9139

jackzhxng authored and facebook-github-bot committed

Accept model type parameter in export_llama (#6507)

Summary: Specify the model to export in the CLI.

Test Plan: Exported the stories110M model.

```
python -m examples.models.llama.export_llama -c stories110M/stories110M.pt -p stories110M/params.json -X -kv
```

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- [Llama2 model cleanup](#5859)
- **YOU ARE HERE ~>** [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Runner changes for TorchTune Llama3.2 vision text decoder](#6610)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Reviewed By: helunwencser

Differential Revision: D65612837

Pulled By: dvorjackz

1 parent d8a617f commit dbd9139
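The new `--model` flag added below in export_llama_lib.py (default `llama3`) makes the target architecture explicit on the command line. As a minimal sketch, the test-plan export above could equivalently name the model, reusing the same placeholder paths:

```
python -m examples.models.llama.export_llama --model stories110m -c stories110M/stories110M.pt -p stories110M/params.json -X -kv
```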

File tree: 7 files changed (+53, -22 lines)

docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure

 ```bash
 # Please note that calibration_data must include the prompt template for special tokens.
-python -m examples.models.llama.export_llama -t <path_to_tokenizer.model>
+python -m examples.models.llama.export_llama -t <path_to_tokenizer.model>
 llama3/Meta-Llama-3-8B-Instruct/tokenizer.model -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 ```

examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure
 * 8B models might need 16GB RAM on the device to run.
 ```
 # Please note that calibration_data must include the prompt template for special tokens.
-python -m examples.models.llama.export_llama -t <path_to_tokenizer.model> -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+python -m examples.models.llama.export_llama -t <path_to_tokenizer.model> -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 ```

 ## Pushing Model and Tokenizer

examples/models/llama/README.md

Lines changed: 12 additions & 2 deletions
@@ -239,9 +239,19 @@ You can export and run the original Llama 3 8B instruct model.

 2. Export model and generate `.pte` file
 ```
-python -m examples.models.llama.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+python -m examples.models.llama.export_llama \
+--checkpoint <consolidated.00.pth> \
+-p <params.json> \
+-kv \
+--use_sdpa_with_kv_cache \
+-X \
+-qmode 8da4w \
+--group_size 128 \
+-d fp32 \
+--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+--embedding-quantize 4,32 \
+--output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
 ```
-
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.

examples/models/llama/eval_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ def gen_eval_wrapper(

     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)

     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)

examples/models/llama/export_llama.py

Lines changed: 1 addition & 2 deletions
@@ -23,10 +23,9 @@
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
-    export_llama(modelname, args)
+    export_llama(args)


 if __name__ == "__main__":

examples/models/llama/export_llama_lib.py

Lines changed: 36 additions & 14 deletions
@@ -81,6 +81,10 @@
 verbosity_setting = None


+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -105,7 +109,7 @@ def verbose_export():


 def build_model(
-    modelname: str = "model",
+    modelname: str = "llama3",
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
@@ -116,11 +120,11 @@ def build_model(
     else:
         output_dir_path = "."

-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)


 def build_args_parser() -> argparse.ArgumentParser:
@@ -130,6 +134,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -480,13 +490,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val


-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph

             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -497,14 +507,14 @@ def export_llama(modelname, args) -> str:
            )
            return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename


-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -530,7 +540,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:

     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -553,7 +563,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )


@@ -627,12 +637,12 @@ def _validate_args(args):
        )


-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

     # export_to_edge
-    builder_exported = _prepare_for_llama_export(modelname, args).export()
+    builder_exported = _prepare_for_llama_export(args).export()

     if args.export_only:
         exit()
@@ -830,8 +840,8 @@ def _load_llama_model_metadata(


 def _load_llama_model(
+    modelname: str = "llama3",
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -859,15 +869,27 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
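Taken together, these changes drop the separate `modelname` parameter from the export entry points; the model name now travels on `args.model`. A short sketch of programmatic use, under assumptions not in this commit (repo root on `PYTHONPATH`; checkpoint and params paths are placeholders):

```python
# Sketch only: drive the updated export_llama(args) API from Python.
# build_args_parser() and export_llama() come from the file changed above;
# the checkpoint/params paths below are placeholders, not part of this commit.
from examples.models.llama.export_llama_lib import build_args_parser, export_llama

parser = build_args_parser()
args = parser.parse_args(
    [
        "--model", "stories110m",                      # new flag introduced here
        "--checkpoint", "stories110M/stories110M.pt",  # placeholder path
        "--params", "stories110M/params.json",         # placeholder path
        "-kv",
        "-X",
    ]
)
pte_path = export_llama(args)  # returns the saved .pte filename
```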

examples/models/llama/runner/eager.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(self, args):
            model_args=model_args,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        manager: LLMEdgeManager = _prepare_for_llama_export(args)
         self.model = manager.model.eval().to(device=self.device)

     def forward(
