@@ -84,6 +84,58 @@ def _model_generate(self, context, max_length, eos_token_id):
84
84
raise Exception ("unimplemented" )
85
85
86
86
87
class ETEagerEvalWrapper(GPTFastEvalWrapper):
    """
    A wrapper class for ExecuTorch Eager integration with the
    lm-evaluation-harness library.

    Holds a reference to the ExecuTorch model (path string — presumably a
    .pte path; TODO confirm against gen_eval_wrapper's caller) and defers
    tokenizer/sequence-length handling to GPTFastEvalWrapper.
    """

    def __init__(
        self,
        model: str,
        tokenizer: SentencePieceProcessor,
        max_seq_length: Optional[int] = None,
    ):
        # Pass None as the model to the base class: this wrapper manages
        # its own model handle rather than a pre-exported torch module.
        super().__init__(None, tokenizer, max_seq_length)
        self._model = model

    def _model_call(self, inps):
        """
        Given inps (a token Tensor of shape (1, N)), return the logits
        from a single forward call, e.g. a Tensor of shape (1, N, 32000).
        """
        # Not implemented yet. Raise explicitly (matching the file's
        # convention for unimplemented paths) instead of falling through
        # with `pass`, which would silently return None and surface as a
        # confusing error downstream in the eval harness.
        raise NotImplementedError(
            "ETEagerEvalWrapper._model_call is not implemented yet"
        )
110
+
111
+
112
class ETRunnerEvalWrapper(GPTFastEvalWrapper):
    """
    A wrapper class for ExecuTorch Runtime integration with the
    lm-evaluation-harness library.

    Holds the ExecuTorch model (path string — presumably a .pte path;
    TODO confirm against gen_eval_wrapper's caller) plus the path to the
    tokenizer binary required by the runtime runner.
    """

    def __init__(
        self,
        model: str,
        tokenizer: SentencePieceProcessor,
        tokenizer_bin: str,
        max_seq_length: Optional[int] = None,
    ):
        # Pass None as the model to the base class: this wrapper manages
        # its own model handle rather than a pre-exported torch module.
        super().__init__(None, tokenizer, max_seq_length)
        self._model = model
        self._tokenizer_bin = tokenizer_bin

    def _model_call(self, inps):
        """
        Given inps (a token Tensor of shape (1, N)), return the logits
        from a single forward call, e.g. a Tensor of shape (1, N, 32000).
        """
        # Not implemented yet. Raise explicitly (matching the file's
        # convention for unimplemented paths) instead of falling through
        # with `pass`, which would silently return None and surface as a
        # confusing error downstream in the eval harness.
        raise NotImplementedError(
            "ETRunnerEvalWrapper._model_call is not implemented yet"
        )
137
+
138
+
87
139
@torch .no_grad ()
88
140
def eval (
89
141
eval_wrapper : LM ,
@@ -131,6 +183,24 @@ def gen_eval_wrapper(
131
183
"""
132
184
tokenizer = SentencePieceProcessor (model_file = str (args .tokenizer_path ))
133
185
186
+ # ExecuTorch Binary Evaluation
187
+ if (model := args .pte ) is not None :
188
+ if (tokenizer_bin := args .tokenizer_bin ) is not None :
189
+ # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
190
+ return ETRunnerEvalWrapper (
191
+ model = model ,
192
+ tokenizer = tokenizer ,
193
+ tokenizer_bin = tokenizer_bin ,
194
+ max_seq_length = args .max_seq_length ,
195
+ )
196
+
197
+ # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated eagerly
198
+ return ETEagerEvalWrapper (
199
+ model = model ,
200
+ tokenizer = tokenizer ,
201
+ max_seq_length = args .max_seq_length ,
202
+ )
203
+
134
204
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
135
205
manager : LlamaEdgeManager = _prepare_for_llama_export (model_name , args )
136
206
model = (
@@ -161,6 +231,21 @@ def build_args_parser() -> argparse.ArgumentParser:
161
231
"--limit" , type = int , default = 5 , help = "number of samples to evalulate"
162
232
)
163
233
234
+ # Add additional args specific to eval via an ET Runner
235
+ # Note: For initial integration, the tokenizer.model is also required
236
+ parser .add_argument (
237
+ "--pte" ,
238
+ type = str ,
239
+ default = None ,
240
+ help = "[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow" ,
241
+ )
242
+ parser .add_argument (
243
+ "--tokenizer_bin" ,
244
+ type = str ,
245
+ default = None ,
246
+ help = "[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime" ,
247
+ )
248
+
164
249
return parser
165
250
166
251
0 commit comments