@@ -81,6 +81,10 @@
 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -105,7 +109,7 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "model",
+    modelname: str = "llama3",
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
@@ -116,11 +120,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -130,6 +134,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -473,13 +483,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -490,14 +500,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -523,7 +533,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -546,7 +556,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -620,13 +630,13 @@ def _validate_args(args):
         )
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
     builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
        .export()
        .pt2e_quantize(quantizers)
        .export_to_edge()
@@ -821,8 +831,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str = "llama3",
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -850,15 +860,27 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
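
Usage sketch (not part of the diff): after this change the model architecture travels inside the parsed args as args.model rather than as a separate positional modelname, so a caller builds the argument namespace and passes only that to export_llama. The checkpoint and params paths below are placeholders, not values taken from this change.

# Minimal sketch of the new call pattern, assuming the functions shown in this diff.
parser = build_args_parser()
args = parser.parse_args(
    [
        "--model", "llama3_2",
        "--checkpoint", "/path/to/checkpoint.pth",  # placeholder path
        "--params", "/path/to/params.json",          # placeholder path
        "--output-dir", ".",
    ]
)
export_llama(args)  # previously export_llama(modelname, args)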