@@ -9,8 +9,8 @@
 from typing import Optional
 
 import lm_eval
-
 import torch
+
 from lm_eval.api.model import LM
 from lm_eval.evaluator import evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
@@ -33,7 +33,7 @@ class GPTFastEvalWrapper(eval_wrapper):
     def __init__(
         self,
         model: nn.Module,
-        tokenizer,
+        tokenizer: SentencePieceProcessor,
         max_seq_length: Optional[int] = None,
     ):
         super().__init__()
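The new annotation makes the expected tokenizer type explicit. A minimal sketch of constructing the wrapper under that contract, assuming a local `tokenizer.model` file and an already-built eager `model` (both placeholders, not part of this diff):

```python
from sentencepiece import SentencePieceProcessor

# Hypothetical path; the diff only tightens the annotation, not where the file lives.
tokenizer = SentencePieceProcessor(model_file="tokenizer.model")

# `model` is assumed to be an eager nn.Module (e.g., a llama Transformer)
# constructed elsewhere.
wrapper = GPTFastEvalWrapper(
    model=model,
    tokenizer=tokenizer,
    max_seq_length=2048,  # arbitrary example value
)
```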
@@ -97,16 +97,18 @@ def __init__(
         max_seq_length: Optional[int] = None,
     ):
         super().__init__(None, tokenizer, max_seq_length)
-        self._model = model
+        self._model = model  # Expects model to be path to a .pte file
 
-    def _model_call(self, inps):
-        # Given inps (tokens), return the logits from a single
-        # forward call
+        from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
-        # Example:
-        # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
-        pass
+        self._et_model = _load_for_executorch(self._model)
+
+    def _model_call(self, inps):
+        # Given inps (tokens), return the logits from a single forward call
+        # inps: Tensor of shape (1, max_seq_len - 1)
+        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
+        result = self._et_model.forward((inps,))
+        return result[0]
 
 
 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
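For reference, the pybindings flow that the filled-in `_model_call` relies on can be exercised standalone. A sketch, assuming a previously exported `llama2.pte` (placeholder path) and the shapes from the comments above:

```python
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_model = _load_for_executorch("llama2.pte")  # placeholder path to an exported program

# Token ids shaped (1, max_seq_len - 1); 127 assumes max_seq_len = 128.
inps = torch.zeros((1, 127), dtype=torch.long)

result = et_model.forward((inps,))  # forward takes a tuple of inputs, returns a list
logits = result[0]                  # per the comments above: (1, max_seq_len - 1, 32000)
```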
@@ -198,7 +200,9 @@ def gen_eval_wrapper(
         return ETEagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
+            # Exported model takes at most (max_seq_length - 1) tokens.
+            # Note that the eager model takes at most max_seq_length tokens.
+            max_seq_length=args.max_seq_length - 1,
         )
 
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
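End to end, the wrapper returned here is meant to be handed to lm-eval's `evaluate`, which the file already imports. A hedged sketch, assuming `gen_eval_wrapper`'s call signature (only its name and `args.max_seq_length` appear in this diff) and a wikitext task:

```python
import torch
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

# Hypothetical call; the diff does not show gen_eval_wrapper's parameters.
eval_wrapper = gen_eval_wrapper(model_name, args)

with torch.no_grad():
    results = evaluate(
        lm=eval_wrapper,
        task_dict=get_task_dict(["wikitext"]),
    )

for task, metrics in results["results"].items():
    print(task, metrics)
```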