@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,13 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+        llm_config = LlmConfig.from_args(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore
@@ -182,7 +189,7 @@ def gen_eval_wrapper(
             model=model,
             tokenizer=tokenizer,
             tokenizer_bin=tokenizer_bin,
-            max_seq_length=args.max_seq_length,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
         )

     # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +198,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )

-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)

     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +217,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +243,8 @@ def gen_eval_wrapper(
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
         )

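For reference, the pattern introduced in gen_eval_wrapper above keeps existing call sites working: the new llm_config parameter defaults to None, and when absent the structured config is derived from the flat argparse namespace via LlmConfig.from_args. Below is a minimal standalone sketch of the same pattern; run_eval is a hypothetical stand-in for gen_eval_wrapper, while the LlmConfig import path matches the one used in this diff.

import argparse
from typing import Optional


def run_eval(args: argparse.Namespace, llm_config: Optional["LlmConfig"] = None):
    # Fall back to deriving the structured config from the flat argparse
    # namespace, so callers that still pass only `args` keep working.
    if llm_config is None:
        from executorch.examples.models.llama.config.llm_config import LlmConfig

        llm_config = LlmConfig.from_args(args)

    # Downstream code then reads grouped fields (llm_config.export.max_seq_length)
    # instead of flat attributes (args.max_seq_length).
    return llm_config

Passing args and llm_config side by side, as gen_eval_wrapper now does, is a transitional design: migrated fields are read from llm_config, while not-yet-migrated ones (args.pte, args.tasks, args.attention_sink_eval_tokens) are still read from args.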
@@ -296,12 +305,16 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+    llm_config = LlmConfig.from_args(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets

@@ -312,8 +325,8 @@ def eval_llama(
     eval_results = simple_evaluate(
         model=eval_wrapper,
         tasks=args.tasks,
-        num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-        limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )

     for task, res in eval_results["results"].items():
@@ -326,19 +339,24 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse

     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+    llm_config = LlmConfig.from_args(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])

-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size

     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

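A note on the assertions in this hunk: use_attention_sink is parsed as a comma-separated triple whose first two fields are the sink size and the sliding-window size, and the exported max_seq_length must equal their sum (the third field is required by the length check but not read in this excerpt). A small illustration follows; the concrete values are assumptions for the example, not defaults from the source.

# Illustrative only: a 4-token sink plus a 2044-token window must match a
# max_seq_length of 2048 to satisfy the assertion above.
use_attention_sink = "4,2044,1024"
sink_size, window_size = (int(v) for v in use_attention_sink.split(",")[:2])
assert sink_size + window_size == 2048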
@@ -347,7 +365,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
             tokens = tokenizer.encode(text, bos=False, eos=False)
             if len(tokens) <= 0:
                 continue
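Taken together, the change replaces flat argparse attribute reads with grouped LlmConfig fields. The mapping, as it appears in this diff:

    args.tokenizer_path       -> llm_config.base.tokenizer_path
    args.max_seq_length       -> llm_config.export.max_seq_length
    args.use_kv_cache         -> llm_config.model.use_kv_cache
    args.enable_dynamic_shape -> llm_config.model.enable_dynamic_shape
    args.use_attention_sink   -> llm_config.model.use_attention_sink

A short sketch of how a caller might inspect the resolved settings; resolved_eval_settings is a hypothetical helper, not part of the change.

from executorch.examples.models.llama.config.llm_config import LlmConfig


def resolved_eval_settings(args):
    # Build the structured config once, as eval_llama now does internally.
    llm_config = LlmConfig.from_args(args)
    return {
        "tokenizer_path": llm_config.base.tokenizer_path,
        "max_seq_length": llm_config.export.max_seq_length,
        "use_kv_cache": llm_config.model.use_kv_cache,
        "enable_dynamic_shape": llm_config.model.enable_dynamic_shape,
    }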