verbosity_setting = None


+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
class WeightType(Enum):
    LLAMA = "LLAMA"
    FAIRSEQ2 = "FAIRSEQ2"
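As a reading aid (not part of the diff), a minimal sketch of how these two lists are meant to partition model support; `resolve_model_family` is a hypothetical helper used only for illustration and mirrors the dispatch added to `_load_llama_model` later in this diff:

```python
EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = []


def resolve_model_family(modelname: str) -> str:
    # Illustrative only: same membership checks as the dispatch in _load_llama_model.
    if modelname in EXECUTORCH_DEFINED_MODELS:
        return "executorch"  # uses the LlamaTransformer definition in ExecuTorch
    if modelname in TORCHTUNE_DEFINED_MODELS:
        return "torchtune"  # TorchTune model definitions (none wired up yet)
    raise ValueError(f"{modelname} is not a valid Llama model.")
```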
@@ -103,7 +107,7 @@ def verbose_export():


def build_model(
-    modelname: str = "model",
+    modelname: str,
    extra_opts: str = "",
    *,
    par_local_output: bool = False,
@@ -114,11 +118,11 @@ def build_model(
    else:
        output_dir_path = "."

-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--modelname {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
    parser = build_args_parser()
    args = parser.parse_args(shlex.split(argString))
    # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)


def build_args_parser() -> argparse.ArgumentParser:
@@ -128,6 +132,12 @@ def build_args_parser() -> argparse.ArgumentParser:
    # parser.add_argument(
    #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
    # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+ help = "The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions." ,
+    )
    parser.add_argument(
        "-E",
        "--embedding-quantize",
@@ -465,13 +475,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
    return return_val


-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
    if args.profile_path is not None:
        try:
            from executorch.util.python_profiler import CProfilerFlameGraph

            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                assert (
                    filename := builder.get_saved_pte_filename()
                ) is not None, "Fail to get file name from builder"
@@ -482,14 +492,14 @@ def export_llama(modelname, args) -> str:
            )
            return ""
    else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
        assert (
            filename := builder.get_saved_pte_filename()
        ) is not None, "Fail to get file name from builder"
        return filename


-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
    """
    Helper function for export_llama. Loads the model from checkpoint and params,
    and sets up a LLMEdgeManager with initial transforms and dtype conversion.
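A hedged usage sketch of the updated entry point; the `--checkpoint`/`--params` flag names and the file paths are placeholders assumed from the existing parser rather than taken from this diff:

```python
# Assumption: build_args_parser() and export_llama() come from this module,
# and the checkpoint/params paths point at real files on disk.
parser = build_args_parser()
args = parser.parse_args(
    ["--model", "llama3", "--checkpoint", "ckpt.pt", "--params", "params.json"]
)
pte_path = export_llama(args)  # previously export_llama(modelname, args)
print(pte_path)
```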
@@ -515,7 +525,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:

    return (
        _load_llama_model(
-            modelname=modelname,
+            args.model,
            checkpoint=checkpoint_path,
            checkpoint_dir=checkpoint_dir,
            params_path=params_path,
@@ -538,7 +548,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
            args=args,
        )
        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
    )


@@ -612,13 +622,13 @@ def _validate_args(args):
        )


-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
    _validate_args(args)
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    # export_to_edge
    builder_exported_to_edge = (
-        _prepare_for_llama_export(modelname, args)
+        _prepare_for_llama_export(args)
        .export()
        .pt2e_quantize(quantizers)
        .export_to_edge()
@@ -804,8 +814,8 @@ def _load_llama_model_metadata(


def _load_llama_model(
+    modelname: str = "llama3",
    *,
-    modelname: str = "llama2",
    checkpoint: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    params_path: str,
@@ -833,15 +843,27 @@ def _load_llama_model(
    Returns:
        An instance of LLMEdgeManager which contains the eager mode model.
    """
+
    assert (
        checkpoint or checkpoint_dir
    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
    logging.info(
        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
    )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
        checkpoint=checkpoint,
        checkpoint_dir=checkpoint_dir,
        params=params_path,