@@ -11,6 +11,7 @@
 import lm_eval
 
 import torch
+from lm_eval.api.model import LM
 from lm_eval.evaluator import evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
 from lm_eval.tasks import get_task_dict
@@ -26,7 +27,7 @@
 
 class GPTFastEvalWrapper(eval_wrapper):
     """
-    A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
+    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
     """
 
     def __init__(
@@ -85,21 +86,17 @@ def _model_generate(self, context, max_length, eos_token_id):
 
 @torch.no_grad()
 def eval(
-    model: nn.Module,
-    tokenizer,
+    eval_wrapper: LM,
     tasks: Optional[list] = None,
     limit: Optional[int] = None,
-    max_seq_length: Optional[int] = None,
 ) -> dict:
     """
     Evaluates a language model on a specified task using the lm-evaluation-harness library.
 
     Args:
-        model (nn.Module): The pre-trained language model to evaluate.
-        tokenizer: The tokenizer to use for encoding/decoding text.
+        eval_wrapper (LM): An LM wrapper class compatible with lm-evaluation-harness evaluation.
         task (str): The name of the evaluation task to perform.
         limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
-        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
 
     Returns:
         eval_results (dict): A dictionary of evaluation results for the specified task(s).
@@ -108,25 +105,46 @@ def eval(
     if tasks is None:
         tasks = ["wikitext"]
 
-    model_eval_wrapper = GPTFastEvalWrapper(
-        model,
-        tokenizer,
-        max_seq_length,
-    )
-
     if "hendrycks_test" in tasks:
         tasks.remove("hendrycks_test")
         tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
     task_dict = get_task_dict(tasks)
 
     eval_results = evaluate(
-        model_eval_wrapper,
+        eval_wrapper,
         task_dict,
         limit=limit,
     )
     return eval_results
 
 
+def gen_eval_wrapper(
+    model_name: str,
+    args: argparse.ArgumentParser,
+) -> LM:
+    """
+    Generates a wrapper interface around the provided model and tokenizer for
+    the lm-evaluation-harness library.
+
+    Returns:
+        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
+    """
+    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
+
+    # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
+    manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
+    model = (
+        manager.model.eval().to(device="cuda")
+        if torch.cuda.is_available()
+        else manager.model.to(device="cpu")
+    )
+    return GPTFastEvalWrapper(
+        model=model,
+        tokenizer=tokenizer,
+        max_seq_length=args.max_seq_length,
+    )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
@@ -150,22 +168,14 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
-    # Get a pre-lowering/to_edge LlamaEdgeManager instance
-    manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
-    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, args)
 
     # Evaluate the model
-    model = (
-        manager.model.eval().to(device="cuda")
-        if torch.cuda.is_available()
-        else manager.model.to(device="cpu")
-    )
     eval_results = eval(
-        model,
-        tokenizer,
+        eval_wrapper,
         args.tasks,
         args.limit,
-        args.max_seq_length,
     )
 
     for task, res in eval_results["results"].items():
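For orientation, here is a minimal usage sketch of the refactored flow: gen_eval_wrapper turns the exported Llama model into an lm-evaluation-harness LM object, and eval now accepts any such wrapper directly. The import path and the model name below are assumptions for illustration; only the function names and the args attributes used above (tokenizer_path, max_seq_length, tasks, limit) come from this diff.

# Hypothetical driver script; the module path and model name are assumptions, not part of this PR.
from eval_llama_lib import build_args_parser, eval, gen_eval_wrapper

args = build_args_parser().parse_args()          # supplies tokenizer_path, max_seq_length, tasks, limit
eval_wrapper = gen_eval_wrapper("llama2", args)  # wrap the exported model as an lm-eval-harness LM
eval_results = eval(eval_wrapper, tasks=args.tasks, limit=args.limit)

for task, res in eval_results["results"].items():
    print(f"{task}: {res}")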