 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
-import time
-from typing import Optional
+from typing import Callable, Optional

 import torch
 import torch._dynamo.config
 from build.model import Transformer
 from build.utils import set_precision
 from cli import add_arguments_for_verb, arg_init
-from generate import encode_tokens, model_forward
 from utils.measure_time import measure_time

 torch._dynamo.config.automatic_dynamic_shapes = True
@@ -85,11 +83,17 @@ def __init__(
         self,
         model: Transformer,
         tokenizer,
+        model_forward: Optional[Callable] = None,
         max_seq_length: Optional[int] = None,
         device="cpu",
     ):
         super().__init__(device=device)
         self._model = model
+        self._model_forward = (
+            model_forward
+            if model_forward is not None
+            else lambda x, input_pos: model(x, input_pos)
+        )
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
@@ -116,11 +120,8 @@ def device(self):
         return self._device

     def tok_encode(self, string: str, **kwargs):
-        encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device)
-        # encoded is a pytorch tensor, but some internal logic in the
-        # eval harness expects it to be a list instead
-        # TODO: verify this for multi-batch as well
-        encoded = encoded.tolist()
+        bos_id = self._tokenizer.bos_id()
+        encoded = [bos_id] + self._tokenizer.encode(string)
         return encoded

     def tok_decode(self, tokens):
@@ -142,7 +143,7 @@ def _model_call(self, inps):
         )
         x = seq.index_select(0, input_pos).view(1, -1)
         with measure_time(message=None) as measure:
-            logits = model_forward(self._model, x, input_pos)
+            logits = self._model_forward(x, input_pos)
         self.times.append(measure.get_time())
         return logits

@@ -153,6 +154,7 @@ def _model_generate(self, context, max_length, eos_token_id):
 @torch.no_grad()
 def eval(
     model: Transformer,
+    model_forward: Callable,
     tokenizer,
     tasks: Optional[list] = None,
     limit: Optional[int] = None,
@@ -176,7 +178,11 @@ def eval(
         tasks = ["wikitext"]

     model_eval_wrapper = GPTFastEvalWrapper(
-        model, tokenizer, max_seq_length, device=device
+        model,
+        tokenizer,
+        model_forward=model_forward,
+        max_seq_length=max_seq_length,
+        device=device,
     )

     try:
@@ -231,11 +237,12 @@ def main(args) -> None:
     )
     tokenizer_args.validate_model(model)

+    model_forward = lambda x, input_pos: model(x, input_pos)  # noqa
+
     if compile:
         assert not (
             builder_args.dso_path or builder_args.pte_path
         ), "cannot compile exported model"
-        global model_forward
         model_forward = torch.compile(
             model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
         )
@@ -244,6 +251,7 @@ def main(args) -> None:
     with measure_time("Time to run eval: {time:.02f}s."):
         result = eval(
             model.to(device),
+            model_forward,
             tokenizer,
             tasks,
             limit,