
Commit e036de0

Completely remove args from export_llama_lib

ghstack-source-id: fe64963
Pull Request resolved: #11171

1 parent: d2d79f5
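
With `args` gone from the export pipeline, callers convert the parsed CLI namespace to an `LlmConfig` once at the boundary and pass only the config onward. A minimal sketch of the new call flow, mirroring `build_model` in the diff below; the flag values are placeholders, not repository defaults:

```python
# Sketch of the post-change call flow (placeholder flag values).
from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    export_llama,
)

parser = build_args_parser()
args = parser.parse_args(
    ["--model", "llama3", "--checkpoint", "model.pth", "--params", "params.json"]
)

# Convert the argparse.Namespace once; everything downstream
# (_export_llama, _prepare_for_llama_export, _load_llama_model)
# now receives only the LlmConfig.
llm_config = convert_args_to_llm_config(args)
pte_filename = export_llama(llm_config)
```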

File tree: 3 files changed, +24 −26 lines

backends/arm/test/models/test_llama.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -22,7 +22,9 @@
     TosaPipelineMI,
 )
 
-from executorch.examples.models.llama.config.llm_config_utils import convert_args_to_llm_config
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
```

examples/models/llama/export_llama_lib.py

Lines changed: 11 additions & 19 deletions
```diff
@@ -157,7 +157,8 @@ def build_model(
     argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
@@ -579,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
         llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
@@ -626,7 +622,7 @@ def export_llama(
         from executorch.util.python_profiler import CProfilerFlameGraph
 
         with CProfilerFlameGraph(llm_config.debug.profile_path):
-            builder = _export_llama(llm_config, args)
+            builder = _export_llama(llm_config)
             assert (
                 filename := builder.get_saved_pte_filename()
             ) is not None, "Fail to get file name from builder"
@@ -637,14 +633,14 @@
         )
         return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -672,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -695,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -1054,7 +1049,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1066,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1174,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1198,8 +1193,6 @@
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
-    llm_config: Optional[LlmConfig] = None,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
@@ -1208,6 +1201,7 @@
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
+    modelname = llm_config.base.model_class
    if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1220,13 +1214,11 @@
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            model_args={"llm_config": llm_config},
+            llm_config=llm_config,
         )
     )
 
```
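
Pulled together from the hunks above, the dispatch at the top of `export_llama` now looks roughly like this; the profiling branch and empty-filename handling are elided for brevity:

```python
# Consolidated sketch of the updated export_llama entry point (assembled from
# the hunks above, not a verbatim copy of the file).
def export_llama(export_options) -> str:
    if isinstance(export_options, argparse.Namespace):
        # Legacy CLI: normalize to LlmConfig at the boundary.
        llm_config = convert_args_to_llm_config(export_options)
    elif isinstance(export_options, DictConfig):
        # Hydra CLI: the config is used as-is.
        llm_config = export_options
    else:
        raise ValueError(
            "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
        )

    # No args object is created or threaded through anymore.
    builder = _export_llama(llm_config)
    filename = builder.get_saved_pte_filename()
    assert filename is not None, "Fail to get file name from builder"
    return filename
```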

examples/models/llama/model.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -16,6 +16,7 @@
     get_default_model_resource_dir,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
@@ -36,19 +37,19 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, llm_config):
+    def __init__(self, llm_config: LlmConfig):
         resource_dir = get_default_model_resource_dir(__file__)
 
         self.llm_config = llm_config
-
+
         # Use single checkpoint file.
         checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = self.llm_config.base.checkpoint_dir
 
         # Params file.
         params_path = self.llm_config.base.params
-
+
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
         self.generate_full_logits = self.llm_config.debug.generate_full_logits
@@ -101,7 +102,7 @@ def __init__(self, llm_config):
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
-        fairseq2_checkpoint = kwargs.get("fairseq2", False)
+        fairseq2_checkpoint = llm_config.base.fairseq2
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
             checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
@@ -337,12 +338,15 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
         assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
         assert self.llm_config.model.dtype_override, "dtype_override must be specified"
-
+
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]
+        assert (
+            self.llm_config.base.preq_group_size
+            == model_args.quantization_args["group_size"]
+        )
 
         mapping = {
             "fp32": torch.float32,
```

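On the model side, `Llama2Model` is now driven entirely by the `LlmConfig` that `EagerModelFactory.create_model` forwards as a keyword argument. A hedged sketch of the equivalent direct construction; it assumes `LlmConfig` can be instantiated with defaults and its fields assigned afterwards, which this diff does not show:

```python
# Hypothetical construction sketch; field defaults and the exact LlmConfig
# API are assumptions. Only the constructor signature comes from the diff.
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.model import Llama2Model

llm_config = LlmConfig()                    # assumed: dataclass-style defaults
llm_config.base.checkpoint = "model.pth"    # placeholder path
llm_config.base.params = "params.json"      # placeholder path
llm_config.model.use_kv_cache = True

# Previously the factory passed model_args={"llm_config": ...} and __init__
# also consulted kwargs (e.g. kwargs.get("fairseq2")); now the config is the
# single source of truth, e.g. llm_config.base.fairseq2.
model = Llama2Model(llm_config)
```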