
Commit eb49bcf

jackzhxng authored and facebook-github-bot committed
Accept model type parameter in export_llama (#6507)
Summary: Specify the model to export via the CLI.

Test Plan: Exported the stories110M model:

```
python -m examples.models.llama.export_llama -c stories110M/stories110M.pt -p stories110M/params.json -X -kv
```

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- [Llama2 model cleanup](#5859)
- **YOU ARE HERE ~>** [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Runner changes for TorchTune Llama3.2 vision text decoder](#6610)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Differential Revision: D65612837

Pulled By: dvorjackz
1 parent b23c9e6 · commit eb49bcf
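With this change the model type travels inside `args` rather than as a separate positional parameter. A minimal sketch of the new call pattern, assuming you run from the ExecuTorch repo root and substitute real checkpoint/params paths:

```python
# Sketch only: build args with the new --model flag, then export.
from examples.models.llama.export_llama_lib import build_args_parser, export_llama

args = build_args_parser().parse_args(
    [
        "--model", "stories110m",        # new in this PR; defaults to "llama3"
        "-c", "stories110M/stories110M.pt",
        "-p", "stories110M/params.json",
        "-X",                            # delegate to XNNPACK
        "-kv",                           # enable the KV cache
    ]
)
pte_path = export_llama(args)  # model type is now read from args.model internally
print(pte_path)
```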

File tree: 7 files changed, +53 −22 lines


docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure
 
 ```bash
 # Please note that calibration_data must include the prompt template for special tokens.
-python -m examples.models.llama.export_llama -t <path_to_tokenizer.model>
+python -m examples.models.llama.export_llama -t <path_to_tokenizer.model>
 llama3/Meta-Llama-3-8B-Instruct/tokenizer.model -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 ```
 
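As the comment in the diff notes, `--calibration_data` must carry the Llama 3 chat template around the calibration prompt. A small illustrative sketch of how that template string is assembled (the `llama3_prompt` helper is hypothetical, shown only to make the token structure visible):

```python
# Hypothetical helper mirroring the calibration_data string above: each turn is
# wrapped in header tokens, and the string ends with an open assistant turn.
def llama3_prompt(system: str, user: str) -> str:
    return (
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(llama3_prompt("You are a funny chatbot.", "Could you tell me about Facebook?"))
```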

examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure
 * 8B models might need 16GB RAM on the device to run.
 ```
 # Please note that calibration_data must include the prompt template for special tokens.
-python -m examples.models.llama.export_llama -t <path_to_tokenizer.model> -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+python -m examples.models.llama.export_llama -t <path_to_tokenizer.model> -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 ```
 
 ## Pushing Model and Tokenizer

examples/models/llama/README.md

Lines changed: 12 additions & 2 deletions
@@ -239,9 +239,19 @@ You can export and run the original Llama 3 8B instruct model.
 
 2. Export model and generate `.pte` file
 ```
-python -m examples.models.llama.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+python -m examples.models.llama.export_llama \
+  --checkpoint <consolidated.00.pth> \
+  -p <params.json> \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -X \
+  -qmode 8da4w \
+  --group_size 128 \
+  -d fp32 \
+  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+  --embedding-quantize 4,32 \
+  --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
 ```
-
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.
 
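The `4,32` argument means 4-bit weights with a quantization group size of 32. A rough back-of-envelope sketch of why this matters for Llama 3 (shapes assumed for the 8B model: 128256-token vocabulary, 4096-dim embeddings; one fp16 scale per group assumed as an approximation of the real packing):

```python
# Rough size estimate for the Llama 3 token-embedding table (assumed shapes).
vocab, dim, group = 128256, 4096, 32
fp32_mib = vocab * dim * 4 / 2**20                               # 32-bit floats
int4_mib = (vocab * dim / 2 + vocab * dim / group * 2) / 2**20   # 4-bit weights + fp16 scale per group
print(f"fp32: {fp32_mib:.0f} MiB -> 4-bit (group 32): {int4_mib:.0f} MiB")
# fp32: ~2004 MiB -> 4-bit: ~282 MiB (roughly 7x smaller, before zero-point overhead)
```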

examples/models/llama/eval_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ def gen_eval_wrapper(
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)

examples/models/llama/export_llama.py

Lines changed: 1 addition & 2 deletions
@@ -20,10 +20,9 @@
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
-    export_llama(modelname, args)
+    export_llama(args)
 
 
 if __name__ == "__main__":

examples/models/llama/export_llama_lib.py

Lines changed: 36 additions & 14 deletions
@@ -79,6 +79,10 @@
 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -103,7 +107,7 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "model",
+    modelname: str,
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
@@ -114,11 +118,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -128,6 +132,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -465,13 +475,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -482,14 +492,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -515,7 +525,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -538,7 +548,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -612,13 +622,13 @@ def _validate_args(args):
         )
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
     builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
         .export()
         .pt2e_quantize(quantizers)
         .export_to_edge()
@@ -804,8 +814,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str = "llama3",
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -833,15 +843,27 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
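Taken together, the new `--model` flag and the two model lists turn loading into a small dispatch table: the model name selects the module and class that `EagerModelFactory` instantiates. A minimal, self-contained sketch of that routing (the `resolve_model_source` helper name is hypothetical; the committed code inlines this logic in `_load_llama_model`):

```python
from typing import List, Tuple

EXECUTORCH_DEFINED_MODELS: List[str] = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS: List[str] = []

def resolve_model_source(modelname: str) -> Tuple[str, str]:
    """Map a model name to the (module_name, model_class_name) pair
    that EagerModelFactory.create_model would load."""
    if modelname in EXECUTORCH_DEFINED_MODELS:
        # All ExecuTorch-defined variants share the same LlamaTransformer architecture.
        return ("llama", "Llama2Model")
    if modelname in TORCHTUNE_DEFINED_MODELS:
        raise NotImplementedError(
            "Torchtune Llama models are not yet supported in ExecuTorch export."
        )
    raise ValueError(f"{modelname} is not a valid Llama model.")

assert resolve_model_source("llama3_2") == ("llama", "Llama2Model")
```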

examples/models/llama/runner/eager.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(self, args):
             model_args=model_args,
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        manager: LLMEdgeManager = _prepare_for_llama_export(args)
         self.model = manager.model.eval().to(device=self.device)
 
     def forward(
