 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -105,7 +109,7 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "model",
+    modelname: str = "llama3",
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
@@ -116,11 +120,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -130,6 +134,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -480,13 +490,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -497,14 +507,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -530,7 +540,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -553,7 +563,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -627,12 +637,12 @@ def _validate_args(args):
         )
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(modelname, args).export()
+    builder_exported = _prepare_for_llama_export(args).export()
 
     if args.export_only:
         exit()
@@ -830,8 +840,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str = "llama3",
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -859,15 +869,27 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
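
Taken in isolation, the routing added above reduces to a small lookup. A standalone sketch for illustration; _resolve_model_source is a hypothetical helper name, not part of this PR, and it assumes the module-level model lists introduced at the top of this diff are in scope:

# Hypothetical helper mirroring the branch added to _load_llama_model.
def _resolve_model_source(modelname: str) -> tuple[str, str]:
    if modelname in EXECUTORCH_DEFINED_MODELS:
        # Every ExecuTorch-defined Llama variant currently shares one module/class.
        return "llama", "Llama2Model"
    if modelname in TORCHTUNE_DEFINED_MODELS:
        raise NotImplementedError(
            "Torchtune Llama models are not yet supported in ExecuTorch export."
        )
    raise ValueError(f"{modelname} is not a valid Llama model.")

EagerModelFactory.create_model() then receives the resolved module_name and model_class_name positionally instead of the previously hard-coded keyword arguments.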