
Commit b1d2327

jackzhxng authored and facebook-github-bot committed
Accept model type parameter in export_llama (#6507)
Summary: Specify model to export in the CLI.

Test Plan: Exported the stories110M model.

```
python -m examples.models.llama.export_llama -c stories110M/stories110M.pt -p stories110M/params.json -X -kv
```

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- [Llama2 model cleanup](#5859)
- **YOU ARE HERE ~>** [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Runner changes for TorchTune Llama3.2 vision text decoder](#6610)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Reviewed By: helunwencser

Differential Revision: D65612837

Pulled By: dvorjackz
1 parent 4947e27 commit b1d2327
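An illustrative invocation of the new flag (not text from the commit itself; paths reuse the stories110M test plan above, and `--model` defaults to `llama3` when omitted):

```bash
# Illustrative only: select the architecture explicitly with the new --model flag.
python -m examples.models.llama.export_llama \
  --model stories110m \
  -c stories110M/stories110M.pt \
  -p stories110M/params.json \
  -X -kv
```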

File tree

7 files changed, +53 -22 lines changed


docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -39,7 +39,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure
 
 ```bash
 # Please note that calibration_data must include the prompt template for special tokens.
-python -m examples.models.llama.export_llama -t <path_to_tokenizer.model>
+python -m examples.models.llama.export_llama -t <path_to_tokenizer.model>
 llama3/Meta-Llama-3-8B-Instruct/tokenizer.model -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 ```
 
````

examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -158,7 +158,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure
 * 8B models might need 16GB RAM on the device to run.
 ```
 # Please note that calibration_data must include the prompt template for special tokens.
-python -m examples.models.llama.export_llama -t <path_to_tokenizer.model> -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+python -m examples.models.llama.export_llama -t <path_to_tokenizer.model> -p <path_to_params.json> -c <path_to_checkpoint_for_Meta-Llama-3-8B-Instruct> --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path <path_to_optimized_matrix> --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 ```
 
 ## Pushing Model and Tokenizer
````

examples/models/llama/README.md

Lines changed: 12 additions & 2 deletions
````diff
@@ -239,9 +239,19 @@ You can export and run the original Llama 3 8B instruct model.
 
 2. Export model and generate `.pte` file
 ```
-python -m examples.models.llama.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+python -m examples.models.llama.export_llama \
+  --checkpoint <consolidated.00.pth> \
+  -p <params.json> \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -X \
+  -qmode 8da4w \
+  --group_size 128 \
+  -d fp32 \
+  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+  --embedding-quantize 4,32 \
+  --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
 ```
-
 
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.
 
````

examples/models/llama/eval_llama_lib.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -190,7 +190,7 @@ def gen_eval_wrapper(
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
```

examples/models/llama/export_llama.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -23,10 +23,9 @@
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
-    export_llama(modelname, args)
+    export_llama(args)
 
 
 if __name__ == "__main__":
```
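
For readers skimming the diff, a minimal sketch of the updated calling convention: the model choice now travels inside `args` via the new `--model` flag rather than as a separate `modelname` parameter. The import path and flag values below are assumptions based on the files touched here, not part of the commit.

```python
# Sketch only, not part of this commit: drive the new interface programmatically.
# The import path is an assumption based on the repo layout of the touched files.
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    export_llama,
)

parser = build_args_parser()
args = parser.parse_args(
    [
        "--model", "stories110m",        # one of EXECUTORCH_DEFINED_MODELS; default is "llama3"
        "-c", "stories110M/stories110M.pt",
        "-p", "stories110M/params.json",
        "-kv",                           # use KV cache
        "-X",                            # delegate to XNNPACK
    ]
)

# export_llama now takes only `args`; it returns the saved .pte filename.
pte_file = export_llama(args)
print(pte_file)
```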

examples/models/llama/export_llama_lib.py

Lines changed: 36 additions & 14 deletions
```diff
@@ -81,6 +81,10 @@
 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -105,7 +109,7 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "model",
+    modelname: str = "llama3",
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
@@ -116,11 +120,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -130,6 +134,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -473,13 +483,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -490,14 +500,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -523,7 +533,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -546,7 +556,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -620,13 +630,13 @@ def _validate_args(args):
         )
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
     builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
         .export()
         .pt2e_quantize(quantizers)
         .export_to_edge()
@@ -821,8 +831,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str = "llama3",
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -850,15 +860,27 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
```

examples/models/llama/runner/eager.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ def __init__(self, args):
             model_args=model_args,
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        manager: LLMEdgeManager = _prepare_for_llama_export(args)
         self.model = manager.model.eval().to(device=self.device)
 
     def forward(
```
