Completely remove args from export_llama_lib #11171

Closed · wants to merge 1 commit
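This PR removes the untyped `args` namespace from the `export_llama` flow: CLI arguments are converted to a typed `LlmConfig` once, at the entry point, and every helper below it (`_export_llama`, `_prepare_for_llama_export`, `_load_llama_model`) now accepts only the config. A minimal sketch of the resulting call path, mirroring the `build_model` change in the first `export_llama_lib.py` hunk below (the model, checkpoint, and params values are placeholders, not part of this PR):

import shlex

from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    export_llama,
)

# Parse a legacy CLI string into an argparse.Namespace.
parser = build_args_parser()
args = parser.parse_args(
    shlex.split("--model llama3 --checkpoint ckpt.pth --params params.json")
)

# Convert once at the boundary; everything downstream sees only LlmConfig.
llm_config = convert_args_to_llm_config(args)
pte_filename = export_llama(llm_config)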
4 changes: 3 additions & 1 deletion backends/arm/test/models/test_llama.py
@@ -22,7 +22,9 @@
     TosaPipelineMI,
 )
 
-from executorch.examples.models.llama.config.llm_config_utils import convert_args_to_llm_config
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
30 changes: 11 additions & 19 deletions examples/models/llama/export_llama_lib.py
@@ -157,7 +157,8 @@ def build_model(
     argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
@@ -579,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
+        llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
@@ -626,7 +622,7 @@ def export_llama(
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(llm_config.debug.profile_path):
-                builder = _export_llama(llm_config, args)
+                builder = _export_llama(llm_config)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ ... @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -672,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -695,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -1054,7 +1049,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1066,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1174,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1198,8 +1193,6 @@ def _load_llama_model(
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
-    llm_config: Optional[LlmConfig] = None,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
@@ -1208,6 +1201,7 @@ def _load_llama_model(
     An instance of LLMEdgeManager which contains the eager mode model.
     """
 
+    modelname = llm_config.base.model_class
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1220,13 +1214,11 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            model_args={"llm_config": llm_config},
+            llm_config=llm_config,
         )
     )
 
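Taken together, the `export_llama_lib.py` hunks reduce `_load_llama_model` to a config-only interface: the model name is read from `llm_config.base.model_class` instead of a positional `modelname` parameter, and the config is forwarded to the factory as a keyword argument rather than wrapped in a `model_args` dict. A condensed sketch of the new flow, keeping only the lines this PR touches (`EXECUTORCH_DEFINED_MODELS` is a module-level list in this file; the `EagerModelFactory` import path is assumed from the ExecuTorch examples tree; all other parameters are elided):

from executorch.examples.models.llama.export_llama_lib import EXECUTORCH_DEFINED_MODELS
from executorch.examples.models.model_factory import EagerModelFactory

def _load_llama_model_sketch(llm_config):
    # The model name now comes off the typed config, not a parameter.
    modelname = llm_config.base.model_class
    if modelname in EXECUTORCH_DEFINED_MODELS:
        module_name = "llama"
        model_class_name = "Llama2Model"
    else:
        raise ValueError(f"{modelname} is not a valid Llama model.")

    # The config is passed straight through as a keyword argument.
    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
        EagerModelFactory.create_model(
            module_name,
            model_class_name,
            llm_config=llm_config,
        )
    )
    return model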
16 changes: 10 additions & 6 deletions examples/models/llama/model.py
@@ -16,6 +16,7 @@
     get_default_model_resource_dir,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
@@ -36,19 +37,19 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, llm_config):
+    def __init__(self, llm_config: LlmConfig):
         resource_dir = get_default_model_resource_dir(__file__)
 
         self.llm_config = llm_config
 
         # Use single checkpoint file.
         checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = self.llm_config.base.checkpoint_dir
 
         # Params file.
         params_path = self.llm_config.base.params
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
         self.generate_full_logits = self.llm_config.debug.generate_full_logits
@@ -101,7 +102,7 @@ def __init__(self, llm_config):
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
-        fairseq2_checkpoint = kwargs.get("fairseq2", False)
+        fairseq2_checkpoint = llm_config.base.fairseq2
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
             checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
@@ -337,12 +338,15 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
         assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
         assert self.llm_config.model.dtype_override, "dtype_override must be specified"
 
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]
+        assert (
+            self.llm_config.base.preq_group_size
+            == model_args.quantization_args["group_size"]
+        )
 
         mapping = {
             "fp32": torch.float32,