Commit bac2998

Use llm_config instead of args in export_llama functions
Pull Request resolved: #11162 (imported using ghimport)
Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927/)
ghstack-source-id: 289208261
1 parent 62d0580 commit bac2998

File tree

10 files changed: +351 −330 lines


backends/arm/test/models/test_llama.py

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
     TosaPipelineMI,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
@@ -89,8 +90,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = LlmConfig.from_args(args)
 
-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
 
         return llama_model, llama_inputs, llama_meta

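For reference, a minimal sketch (not part of this commit) of the test-side call pattern after this change; the empty flag list is a placeholder for the model-specific arguments the test actually passes:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)

# Parse CLI-style flags as before, then bridge them into the structured config.
args = build_args_parser().parse_args([])  # placeholder: real tests pass model flags here
llm_config = LlmConfig.from_args(args)

# get_llama_model now takes the LlmConfig rather than the raw argparse namespace.
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
```
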
examples/apple/mps/scripts/mps_example.py

Lines changed: 13 additions & 16 deletions
@@ -20,6 +20,7 @@
     serialize_from_bundled_program_to_flatbuffer,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
@@ -131,28 +132,24 @@ def parse_args():
     return args
 
 
-def get_model_config(args):
-    model_config = {}
-    model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
-    model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]
-
-    if args.model_name == "llama2":
-        if args.checkpoint:
-            model_config["checkpoint"] = args.checkpoint
-        if args.params:
-            model_config["params"] = args.params
-        model_config["use_kv_cache"] = True
-    return model_config
-
-
 if __name__ == "__main__":  # noqa: C901
     args = parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
         raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
 
-    model_config = get_model_config(args)
-    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
+    llm_config = LlmConfig()
+    if args.model_name == "llama2":
+        if args.checkpoint:
+            llm_config.base.checkpoint = args.checkpoint
+        if args.params:
+            llm_config.base.params = args.params
+        llm_config.model.use_kv_cache = True
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
+        module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
+        model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
+        llm_config=llm_config,
+    )
 
     model = model.eval()

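A standalone sketch of the new construction path shown in the hunk above. The checkpoint/params values are placeholders, and the import paths for EagerModelFactory and MODEL_NAME_TO_MODEL are my best guess at the ones mps_example.py already uses:

```python
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.model_factory import EagerModelFactory

# Populate only the fields the script previously stuffed into the ad-hoc dict.
llm_config = LlmConfig()
llm_config.base.checkpoint = "/path/to/consolidated.00.pth"  # placeholder path
llm_config.base.params = "/path/to/params.json"  # placeholder path
llm_config.model.use_kv_cache = True

# The factory now receives the structured config instead of **model_config kwargs.
module_name, model_class_name = MODEL_NAME_TO_MODEL["llama2"]
model, example_inputs, _, _ = EagerModelFactory.create_model(
    module_name=module_name,
    model_class_name=model_class_name,
    llm_config=llm_config,
)
```
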
examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama/config:llm_config",
         "//executorch/examples/models:checkpoint",
     ],
 )

examples/models/llama/eval_llama_lib.py

Lines changed: 39 additions & 21 deletions
@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,13 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+        llm_config = LlmConfig.from_args(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore
@@ -182,7 +189,7 @@ def gen_eval_wrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=llm_config.export.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +198,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +217,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +243,8 @@ def gen_eval_wrapper(
     return EagerEvalWrapper(
         model=model,
         tokenizer=tokenizer,
-        max_seq_length=args.max_seq_length,
-        use_kv_cache=args.use_kv_cache,
+        max_seq_length=llm_config.export.max_seq_length,
+        use_kv_cache=llm_config.model.use_kv_cache,
     )
 
 
@@ -296,12 +305,16 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+    llm_config = LlmConfig.from_args(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)
 
     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets
 
@@ -312,8 +325,8 @@ def eval_llama(
     eval_results = simple_evaluate(
         model=eval_wrapper,
         tasks=args.tasks,
-        num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-        limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )
 
     for task, res in eval_results["results"].items():
@@ -326,19 +339,24 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
 
     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+    llm_config = LlmConfig.from_args(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])
 
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
 
     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
 
@@ -347,7 +365,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
            tokens = tokenizer.encode(text, bos=False, eos=False)
            if len(tokens) <= 0:
                continue

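To summarize the split this file now follows, an illustrative helper (not part of the commit) showing which values are read from the derived LlmConfig and which remain argparse-only eval-harness knobs:

```python
import argparse

from executorch.examples.models.llama.config.llm_config import LlmConfig


def wire_eval_settings(args: argparse.Namespace):
    """Illustrative only: model/export settings come from LlmConfig, harness knobs stay on args."""
    llm_config = LlmConfig.from_args(args)  # derived once from the parsed CLI args
    model_side = {
        "tokenizer_path": llm_config.base.tokenizer_path,  # was args.tokenizer_path
        "max_seq_length": llm_config.export.max_seq_length,  # was args.max_seq_length
        "use_kv_cache": llm_config.model.use_kv_cache,  # was args.use_kv_cache
        "enable_dynamic_shape": llm_config.model.enable_dynamic_shape,
    }
    harness_side = {"tasks": args.tasks, "num_fewshot": args.num_fewshot, "limit": args.limit}
    return llm_config, model_side, harness_side
```
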
examples/models/llama/export_llama_hydra.py

Lines changed: 2 additions & 1 deletion
@@ -13,14 +13,15 @@
 from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import export_llama
 from hydra.core.config_store import ConfigStore
+from omegaconf import OmegaConf
 
 cs = ConfigStore.instance()
 cs.store(name="llm_config", node=LlmConfig)
 
 
 @hydra.main(version_base=None, config_name="llm_config")
 def main(llm_config: LlmConfig) -> None:
-    export_llama(llm_config)
+    export_llama(OmegaConf.to_object(llm_config))
 
 
 if __name__ == "__main__":

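The OmegaConf.to_object call matters because Hydra passes main() an omegaconf DictConfig backed by the LlmConfig structured config, not an actual LlmConfig instance; converting back gives export_llama a real dataclass. A minimal standalone illustration (the Export dataclass here is a stand-in, not the real config):

```python
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class Export:  # stand-in for a section of LlmConfig
    max_seq_length: int = 128


cfg = OmegaConf.structured(Export)  # DictConfig proxy, similar to what Hydra injects
assert not isinstance(cfg, Export)

obj = OmegaConf.to_object(cfg)  # materialize the real dataclass instance
assert isinstance(obj, Export) and obj.max_seq_length == 128
```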