@@ -157,7 +157,8 @@ def build_model(
     argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
@@ -579,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
         llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
@@ -626,7 +622,7 @@ def export_llama(
         from executorch.util.python_profiler import CProfilerFlameGraph
 
         with CProfilerFlameGraph(llm_config.debug.profile_path):
-            builder = _export_llama(llm_config, args)
+            builder = _export_llama(llm_config)
             assert (
                 filename := builder.get_saved_pte_filename()
             ) is not None, "Fail to get file name from builder"
@@ -637,14 +633,14 @@ def export_llama(
         )
         return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -672,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -695,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -1054,7 +1049,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1066,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1174,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1198,8 +1193,6 @@ def _load_llama_model(
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
-    llm_config: Optional[LlmConfig] = None,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
@@ -1208,6 +1201,7 @@ def _load_llama_model(
     An instance of LLMEdgeManager which contains the eager mode model.
     """
 
+    modelname = llm_config.base.model_class
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1220,13 +1214,11 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
            model_class_name,
-            model_args={"llm_config": llm_config},
+            llm_config=llm_config,
         )
     )
 
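Taken together, these hunks collapse the dual (llm_config, args) plumbing into a single source of truth: export_llama converts a legacy argparse.Namespace into an LlmConfig up front via convert_args_to_llm_config, and every downstream helper (_prepare_for_llama_export, _export_llama, _load_llama_model) now takes only the config. Below is a minimal sketch of the two call paths after this change. The --model/--checkpoint flags and the base.model_class field are taken from the diff itself; the checkpoint filename, the minimal config contents, and the assumption that build_args_parser/export_llama are importable from this module are illustrative, not verified API.

    import shlex

    from omegaconf import OmegaConf  # Hydra configs are OmegaConf DictConfigs

    # Path 1 (legacy CLI): build an argparse.Namespace; export_llama converts
    # it internally with convert_args_to_llm_config() before doing any work.
    # Flag set abbreviated; a real invocation would pass --params etc. as in
    # build_model() above.
    parser = build_args_parser()
    args = parser.parse_args(shlex.split("--model llama3 --checkpoint ckpt.pt"))
    pte_path = export_llama(args)

    # Path 2 (Hydra CLI): pass a DictConfig straight through; the old
    # Namespace-backfill shim (the setattr loop) is gone, so only LlmConfig
    # fields are visible downstream.
    llm_config = OmegaConf.create({"base": {"model_class": "llama3"}})  # hypothetical minimal config
    pte_path = export_llama(llm_config)

In practice a real config carries many more fields (checkpoint, export.export_only, backend.qnn.enabled, debug.profile_path, all visible in the diff); the point of the sketch is only the single-argument contract.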