 from typing import Optional, Union

 import torch
-from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
 from executorch.examples.models.llama2.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken

-from executorch.extension.llm.export import LLMEdgeManager
+from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.api.model import LM

+from .evaluate.eager_eval import EagerEvalWrapper, evaluate_model
+
 from .export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
@@ -91,7 +92,7 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
         self._model = model  # Expects model to be path to a .pte file

         from executorch.extension.pybindings.portable_lib import _load_for_executorch
@@ -106,7 +107,7 @@ def __init__(
         from executorch.kernels import quantized  # noqa

         self._et_model = _load_for_executorch(self._model)
-        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]
+        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore

     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
@@ -140,7 +141,7 @@ def __init__(
         tokenizer_bin: str,
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
         self._model = model
         self._tokenizer_bin = tokenizer_bin

@@ -165,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore

     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:
-        if (tokenizer_bin := args.tokenizer_bin) is not None:
+    if (model := args.pte) is not None:  # pyre-ignore
+        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,
+                max_seq_length=args.max_seq_length,  # pyre-ignore
             )

     # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -194,16 +195,16 @@ def gen_eval_wrapper(
     if len(quantizers) != 0:
         manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
         model = (
-            manager.pre_autograd_graph_module.to(device="cuda")
+            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
             max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
-            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -221,7 +222,7 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:
+        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
             torch.save(model, args.output_eager_checkpoint_file)

         return EagerEvalWrapper(
@@ -282,8 +283,8 @@ def eval_llama(
     # Evaluate the model
     eval_results = evaluate_model(
         eval_wrapper,
-        args.tasks,
-        args.limit,
+        args.tasks,  # pyre-ignore
+        args.limit,  # pyre-ignore
     )

     for task, res in eval_results["results"].items():