verbosity_setting = None


+EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+
+
class WeightType(Enum):
    LLAMA = "LLAMA"
    FAIRSEQ2 = "FAIRSEQ2"
@@ -113,11 +117,11 @@ def build_model(
    else:
        output_dir_path = "."

-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
    parser = build_args_parser()
    args = parser.parse_args(shlex.split(argString))
    # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)


def build_args_parser() -> argparse.ArgumentParser:
@@ -127,6 +131,12 @@ def build_args_parser() -> argparse.ArgumentParser:
    # parser.add_argument(
    #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
    # )
+    parser.add_argument(
+        "--model",
+        default="llama2",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model to export. llama2, llama3, llama3_1, and llama3_2 share the same architecture, so they are technically interchangeable, provided you supply the checkpoint file for the desired version.",
+    )
    parser.add_argument(
        "-E",
        "--embedding-quantize",
@@ -458,13 +468,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
    return return_val


-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
    if args.profile_path is not None:
        try:
            from executorch.util.python_profiler import CProfilerFlameGraph

            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                assert (
                    filename := builder.get_saved_pte_filename()
                ) is not None, "Fail to get file name from builder"
@@ -475,14 +485,14 @@ def export_llama(modelname, args) -> str:
            )
            return ""
    else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
        assert (
            filename := builder.get_saved_pte_filename()
        ) is not None, "Fail to get file name from builder"
        return filename


-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
    """
    Helper function for export_llama. Loads the model from checkpoint and params,
    and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -508,7 +518,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:

    return (
        _load_llama_model(
-            modelname=modelname,
+            args.model,
            checkpoint=checkpoint_path,
            checkpoint_dir=checkpoint_dir,
            params_path=params_path,
@@ -530,7 +540,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
            args=args,
        )
        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
    )

@@ -574,13 +584,13 @@ def _validate_args(args):
        raise ValueError("Model shard is only supported with qnn backend now.")


-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
    _validate_args(args)
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    # export_to_edge
    builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
        .capture_pre_autograd_graph()
        .pt2e_quantize(quantizers)
        .export_to_edge()
@@ -748,8 +758,8 @@ def _load_llama_model_metadata(


def _load_llama_model(
+    modelname: str,
    *,
-    modelname: str = "llama2",
    checkpoint: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    params_path: str,
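Since modelname now precedes the keyword-only marker and loses its "llama2" default, every caller must pass it explicitly. A hypothetical call shape under this change (placeholder paths; the remaining keyword arguments are elided):

    # Hypothetical call shape; paths are placeholders, other kwargs elided.
    manager = _load_llama_model(
        args.model,
        checkpoint="llama3_1_ckpt.pt",
        params_path="llama3_1_params.json",
        args=args,
    )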
@@ -776,26 +786,41 @@ def _load_llama_model(
    Returns:
        An instance of LLMEdgeManager which contains the eager mode model.
    """
+
    assert (
        checkpoint or checkpoint_dir
    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
    logging.info(
        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
    )
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        "llama2",
-        "Llama2Model",
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        # Set to llama2 because all models in EXECUTORCH_DEFINED_MODELS share the
+        # same architecture as defined in examples/models/llama2.
+        modelname = "llama2"
+        model_class_name = "Llama2Model"
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        if modelname == "llama3_2_vision":
+            model_class_name = "Llama3_2Decoder"
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
+    model, example_inputs, example_kwarg_inputs, _ = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
    )
    if dtype_override:
        assert isinstance(
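For context on the dispatch above: the (modelname, model_class_name) pair is handed to EagerModelFactory.create_model, which is not part of this diff. A rough sketch of the resolution it is assumed to perform (dynamic import of a module under examples/models, then instantiation of the named class):

    # Hypothetical sketch of a factory-style lookup; the real
    # EagerModelFactory lives outside this diff and may differ in detail.
    import importlib

    def create_model_sketch(module_name: str, model_class_name: str, **kwargs):
        # Dynamically import e.g. executorch.examples.models.llama2 and pull
        # out Llama2Model (or Llama3_2Decoder for the torchtune path).
        module = importlib.import_module(
            f"executorch.examples.models.{module_name}"
        )
        model_class = getattr(module, model_class_name)
        return model_class(**kwargs)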