
Use llm_config instead of args in export_llama functions #11162

Merged: 23 commits, Jun 10, 2025

Commits
f69655c  Use llm_config instead of args in export_llama functions (jackzhxng, May 27, 2025)
d9c70c2  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 27, 2025)
209fd7f  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 27, 2025)
45571eb  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
97ec69c  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
b928cc7  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
b08f22b  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
00aa0e8  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 29, 2025)
20bdaa6  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 2, 2025)
900bbdf  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 3, 2025)
a14f548  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 3, 2025)
d7d33d7  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 3, 2025)
4a875d8  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 4, 2025)
6f6bf53  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
4760311  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
1a85097  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
792022d  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
54477dc  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 6, 2025)
6f3e0a5  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 6, 2025)
c447cbd  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 6, 2025)
679fe9e  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 9, 2025)
52455bc  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 9, 2025)
9a15088  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 9, 2025)
4 changes: 3 additions & 1 deletion backends/arm/test/models/test_llama.py
@@ -22,6 +22,7 @@
TosaPipelineMI,
)

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_llama_model,
@@ -89,8 +90,9 @@ def prepare_model(self):
]
parser = build_args_parser()
args = parser.parse_args(args)
llm_config = LlmConfig.from_args(args)

llama_model, llama_inputs, llama_meta = get_llama_model(args)
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)

return llama_model, llama_inputs, llama_meta

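For context, a minimal self-contained sketch of the migration pattern this diff applies: parse CLI arguments as before, convert them once with LlmConfig.from_args, and hand the config object to downstream helpers. The flag values below are placeholders for illustration, not taken from the test.

# Hypothetical sketch of the new call pattern (flag values are placeholders).
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)

cli_args = ["--checkpoint", "stories110M.pt", "--params", "params.json"]

parser = build_args_parser()
args = parser.parse_args(cli_args)

# Single conversion point from argparse.Namespace to the structured config.
llm_config = LlmConfig.from_args(args)

# Downstream functions now consume llm_config rather than raw args.
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
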
29 changes: 13 additions & 16 deletions examples/apple/mps/scripts/mps_example.py
@@ -20,6 +20,7 @@
serialize_from_bundled_program_to_flatbuffer,
)

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
@@ -131,28 +132,24 @@ def parse_args():
return args


def get_model_config(args):
model_config = {}
model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]

if args.model_name == "llama2":
if args.checkpoint:
model_config["checkpoint"] = args.checkpoint
if args.params:
model_config["params"] = args.params
model_config["use_kv_cache"] = True
return model_config


if __name__ == "__main__": # noqa: C901
args = parse_args()

if args.model_name not in MODEL_NAME_TO_MODEL:
raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")

model_config = get_model_config(args)
model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
llm_config = LlmConfig()
if args.model_name == "llama2":
if args.checkpoint:
llm_config.base.checkpoint = args.checkpoint
if args.params:
llm_config.base.params = args.params
llm_config.model.use_kv_cache = True
model, example_inputs, _, _ = EagerModelFactory.create_model(
module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
llm_config=llm_config,
)

model = model.eval()

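The same change, condensed into plain code for reference: the loose kwargs dict built by get_model_config is replaced by an LlmConfig populated field by field. Import paths for the factory and the model registry are assumptions, and the checkpoint/params paths are placeholders.

from executorch.examples.models import MODEL_NAME_TO_MODEL  # assumed import path
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.model_factory import EagerModelFactory  # assumed import path

# Build the config field by field instead of passing a loose dict of kwargs.
llm_config = LlmConfig()
llm_config.base.checkpoint = "/path/to/llama2/checkpoint.pt"  # placeholder path
llm_config.base.params = "/path/to/params.json"  # placeholder path
llm_config.model.use_kv_cache = True

module_name, model_class_name = MODEL_NAME_TO_MODEL["llama2"]
model, example_inputs, _, _ = EagerModelFactory.create_model(
    module_name=module_name,
    model_class_name=model_class_name,
    llm_config=llm_config,
)
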
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -67,6 +67,7 @@ runtime.python_library(
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama/config:llm_config",
"//executorch/examples/models:checkpoint",
],
)
60 changes: 39 additions & 21 deletions examples/models/llama/eval_llama_lib.py
@@ -164,6 +164,7 @@ def _model_call(self, inps):
def gen_eval_wrapper(
model_name: str,
args: argparse.ArgumentParser,
llm_config=None,
):
"""
Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,13 @@
Returns:
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
"""
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore
# If llm_config is not provided, convert args to llm_config
if llm_config is None:
from executorch.examples.models.llama.config.llm_config import LlmConfig

llm_config = LlmConfig.from_args(args)

tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

# ExecuTorch Binary Evaluation
if (model := args.pte) is not None: # pyre-ignore
@@ -182,7 +189,7 @@
model=model,
tokenizer=tokenizer,
tokenizer_bin=tokenizer_bin,
max_seq_length=args.max_seq_length, # pyre-ignore
max_seq_length=llm_config.export.max_seq_length,
)

# ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +198,14 @@
tokenizer=tokenizer,
# Exported model takes at most (max_seq_length - 1) tokens.
# Note that the eager model takes at most max_seq_length tokens.
max_seq_length=args.max_seq_length - 1,
max_seq_length=llm_config.export.max_seq_length - 1,
)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
manager: LLMEdgeManager = _prepare_for_llama_export(args)
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)

if len(quantizers) != 0:
manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +217,9 @@
return GraphModuleEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache, # pyre-ignore
enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore
max_seq_length=llm_config.export.max_seq_length,
use_kv_cache=llm_config.model.use_kv_cache,
enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
)
else:
# TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +243,8 @@ def gen_eval_wrapper(
return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
max_seq_length=llm_config.export.max_seq_length,
use_kv_cache=llm_config.model.use_kv_cache,
)


@@ -296,12 +305,16 @@ def eval_llama(
model_name: str,
args: argparse.ArgumentParser,
) -> None:
# Convert args to LlmConfig
from executorch.examples.models.llama.config.llm_config import LlmConfig

llm_config = LlmConfig.from_args(args)

# Generate the eval wrapper
eval_wrapper = gen_eval_wrapper(model_name, args)
eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
# pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
if args.tasks and "mmlu" in args.tasks:
import datasets

@@ -312,8 +325,8 @@
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=args.tasks,
num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
num_fewshot=args.num_fewshot,
limit=args.limit,
)

for task, res in eval_results["results"].items():
@@ -326,19 +339,24 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
"""
assert args.use_attention_sink is not None # pyre-ignore [16]
assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16]
attention_sink_params = args.use_attention_sink.split(",")
# Convert args to LlmConfig
from executorch.examples.models.llama.config.llm_config import LlmConfig

llm_config = LlmConfig.from_args(args)

assert llm_config.model.use_attention_sink is not None
assert args.attention_sink_eval_tokens > 0
attention_sink_params = llm_config.model.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert args.max_seq_length == sink_size + window_size # pyre-ignore [16]
assert llm_config.export.max_seq_length == sink_size + window_size

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(args)
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
model = manager.model.eval().to(device=device)
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16]
tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

@@ -347,7 +365,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
progress_bar = tqdm(total=args.attention_sink_eval_tokens)
input_pos = 0
while input_pos < args.attention_sink_eval_tokens:
for text in eval_data["text"]: # pyre-ignore [16]
for text in eval_data["text"]:
tokens = tokenizer.encode(text, bos=False, eos=False)
if len(tokens) <= 0:
continue
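
The pattern used throughout eval_llama_lib.py is the same: keep accepting argparse args for backward compatibility, convert them to an LlmConfig up front, and read only structured config fields afterwards. A minimal sketch follows; the function name and body are illustrative, not the actual implementation.

import argparse


def run_eval_sketch(args: argparse.Namespace, llm_config=None):
    # Legacy callers pass only argparse args; convert them once.
    if llm_config is None:
        from executorch.examples.models.llama.config.llm_config import LlmConfig

        llm_config = LlmConfig.from_args(args)

    # Downstream code reads structured fields instead of flat attributes.
    tokenizer_path = llm_config.base.tokenizer_path
    max_seq_length = llm_config.export.max_seq_length
    use_kv_cache = llm_config.model.use_kv_cache
    return tokenizer_path, max_seq_length, use_kv_cache
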
3 changes: 2 additions & 1 deletion examples/models/llama/export_llama_hydra.py
@@ -13,14 +13,15 @@
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

cs = ConfigStore.instance()
cs.store(name="llm_config", node=LlmConfig)


@hydra.main(version_base=None, config_name="llm_config")
def main(llm_config: LlmConfig) -> None:
export_llama(llm_config)
export_llama(OmegaConf.to_object(llm_config))


if __name__ == "__main__":
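
The added OmegaConf.to_object call is what turns hydra's DictConfig back into an actual LlmConfig instance before it reaches export_llama. A small sketch of that conversion, assuming LlmConfig is a dataclass-backed structured config (which the ConfigStore registration above suggests):

from omegaconf import OmegaConf

from executorch.examples.models.llama.config.llm_config import LlmConfig

# hydra hands the decorated main() a DictConfig built from this schema.
cfg = OmegaConf.structured(LlmConfig)

# to_object resolves it back into a real LlmConfig instance, so
# export_llama receives a typed object rather than a DictConfig.
llm_config = OmegaConf.to_object(cfg)
assert isinstance(llm_config, LlmConfig)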