77
77
verbosity_setting = None
78
78
79
79
80
+ EXECUTORCH_DEFINED_MODELS = ["llama2" , "llama3" , "llama3_1" , "llama3_2" ]
81
+ TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision" ]
82
+
83
+
80
84
class WeightType (Enum ):
81
85
LLAMA = "LLAMA"
82
86
FAIRSEQ2 = "FAIRSEQ2"
@@ -112,11 +116,11 @@ def build_model(
112
116
else :
113
117
output_dir_path = "."
114
118
115
- argString = f"--checkpoint par:{ modelname } _ckpt.pt --params par:{ modelname } _params.json { extra_opts } --output-dir { output_dir_path } "
119
+ argString = f"--model { modelname } -- checkpoint par:{ modelname } _ckpt.pt --params par:{ modelname } _params.json { extra_opts } --output-dir { output_dir_path } "
116
120
parser = build_args_parser ()
117
121
args = parser .parse_args (shlex .split (argString ))
118
122
# pkg_name = resource_pkg_name
119
- return export_llama (modelname , args )
123
+ return export_llama (args )
120
124
121
125
122
126
def build_args_parser () -> argparse .ArgumentParser :
@@ -126,6 +130,12 @@ def build_args_parser() -> argparse.ArgumentParser:
126
130
# parser.add_argument(
127
131
# "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
128
132
# )
133
+ parser .add_argument (
134
+ "--model" ,
135
+ default = "llama2" ,
136
+ choices = EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS ,
137
+ help = "The Lllama model to export. llama2, llama3, llama3_1, llama3_2 share the same architecture, so they are technically interchangeable, given you provide the checkpoint file for the desired version." ,
138
+ )
129
139
parser .add_argument (
130
140
"-E" ,
131
141
"--embedding-quantize" ,
@@ -456,13 +466,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
456
466
return return_val
457
467
458
468
459
- def export_llama (modelname , args ) -> str :
469
+ def export_llama (args ) -> str :
460
470
if args .profile_path is not None :
461
471
try :
462
472
from executorch .util .python_profiler import CProfilerFlameGraph
463
473
464
474
with CProfilerFlameGraph (args .profile_path ):
465
- builder = _export_llama (modelname , args )
475
+ builder = _export_llama (args )
466
476
assert (
467
477
filename := builder .get_saved_pte_filename ()
468
478
) is not None , "Fail to get file name from builder"
@@ -473,14 +483,14 @@ def export_llama(modelname, args) -> str:
473
483
)
474
484
return ""
475
485
else :
476
- builder = _export_llama (modelname , args )
486
+ builder = _export_llama (args )
477
487
assert (
478
488
filename := builder .get_saved_pte_filename ()
479
489
) is not None , "Fail to get file name from builder"
480
490
return filename
481
491
482
492
483
- def _prepare_for_llama_export (modelname : str , args ) -> LLMEdgeManager :
493
+ def _prepare_for_llama_export (args ) -> LLMEdgeManager :
484
494
"""
485
495
Helper function for export_llama. Loads the model from checkpoint and params,
486
496
and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -506,7 +516,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
506
516
507
517
return (
508
518
_load_llama_model (
509
- modelname = modelname ,
519
+ args . model ,
510
520
checkpoint = checkpoint_path ,
511
521
checkpoint_dir = checkpoint_dir ,
512
522
params_path = params_path ,
@@ -528,7 +538,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
528
538
args = args ,
529
539
)
530
540
.set_output_dir (output_dir_path )
531
- .source_transform (_get_source_transforms (modelname , dtype_override , args ))
541
+ .source_transform (_get_source_transforms (args . model , dtype_override , args ))
532
542
)
533
543
534
544
@@ -566,13 +576,13 @@ def _validate_args(args):
566
576
raise ValueError ("Model shard is only supported with qnn backend now." )
567
577
568
578
569
- def _export_llama (modelname , args ) -> LLMEdgeManager : # noqa: C901
579
+ def _export_llama (args ) -> LLMEdgeManager : # noqa: C901
570
580
_validate_args (args )
571
581
pt2e_quant_params , quantizers , quant_dtype = get_quantizer_and_quant_params (args )
572
582
573
583
# export_to_edge
574
584
builder_exported_to_edge = (
575
- _prepare_for_llama_export (modelname , args )
585
+ _prepare_for_llama_export (args )
576
586
.capture_pre_autograd_graph ()
577
587
.pt2e_quantize (quantizers )
578
588
.export_to_edge ()
@@ -746,8 +756,8 @@ def _load_llama_model_metadata(
746
756
747
757
748
758
def _load_llama_model (
759
+ modelname : str ,
749
760
* ,
750
- modelname : str = "llama2" ,
751
761
checkpoint : Optional [str ] = None ,
752
762
checkpoint_dir : Optional [str ] = None ,
753
763
params_path : str ,
@@ -774,26 +784,41 @@ def _load_llama_model(
774
784
Returns:
775
785
An instance of LLMEdgeManager which contains the eager mode model.
776
786
"""
787
+
777
788
assert (
778
789
checkpoint or checkpoint_dir
779
790
) and params_path , "Both checkpoint/checkpoint_dir and params can't be empty"
780
791
logging .info (
781
792
f"Loading model with checkpoint={ checkpoint } , params={ params_path } , use_kv_cache={ use_kv_cache } , weight_type={ weight_type } "
782
793
)
783
- model , example_inputs , _ = EagerModelFactory .create_model (
784
- "llama2" ,
785
- "Llama2Model" ,
786
- checkpoint = checkpoint ,
787
- checkpoint_dir = checkpoint_dir ,
788
- params = params_path ,
789
- use_kv_cache = use_kv_cache ,
790
- use_sdpa_with_kv_cache = use_sdpa_with_kv_cache ,
791
- generate_full_logits = generate_full_logits ,
792
- fairseq2 = weight_type == WeightType .FAIRSEQ2 ,
793
- max_seq_len = max_seq_len ,
794
- enable_dynamic_shape = enable_dynamic_shape ,
795
- output_prune_map_path = output_prune_map_path ,
796
- args = args ,
794
+
795
+ if modelname in EXECUTORCH_DEFINED_MODELS :
796
+ # Set to llama2 because all models in EXECUTORCH_DEFINED_MODELS share the same archteciture as
797
+ # defined in example/models/llama2.
798
+ modelname = "llama2"
799
+ model_class_name = "Llama2Model"
800
+ elif modelname in TORCHTUNE_DEFINED_MODELS :
801
+ if modelname == "llama3_2_vision" :
802
+ model_class_name = "Llama3_2Decoder"
803
+ else :
804
+ raise ValueError (f"{ modelname } is not a valid Llama model." )
805
+
806
+ model , (example_inputs , example_kwarg_inputs ), dynamic_shapes = (
807
+ EagerModelFactory .create_model (
808
+ modelname ,
809
+ model_class_name ,
810
+ checkpoint = checkpoint ,
811
+ checkpoint_dir = checkpoint_dir ,
812
+ params = params_path ,
813
+ use_kv_cache = use_kv_cache ,
814
+ use_sdpa_with_kv_cache = use_sdpa_with_kv_cache ,
815
+ generate_full_logits = generate_full_logits ,
816
+ fairseq2 = weight_type == WeightType .FAIRSEQ2 ,
817
+ max_seq_len = max_seq_len ,
818
+ enable_dynamic_shape = enable_dynamic_shape ,
819
+ output_prune_map_path = output_prune_map_path ,
820
+ args = args ,
821
+ )
797
822
)
798
823
if dtype_override :
799
824
assert isinstance (
0 commit comments