Skip to content

Commit 0fdd6ba

Browse files
committed
Update
[ghstack-poisoned]
1 parent c2a0002 commit 0fdd6ba

File tree

4 files changed

+711
-642
lines changed

4 files changed

+711
-642
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ python3 torchchat.py generate llama3 --prompt "write me a story about a boy and
118118

119119
For more information run `python3 torchchat.py generate --help`
120120

121-
122121
### Browser
123122

124123
[skip default]: begin

eval.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66
import argparse
7-
import time
8-
from typing import Optional
7+
from typing import Callable, Optional
98

109
import torch
1110
import torch._dynamo.config
@@ -20,7 +19,6 @@
2019
from build.model import Transformer
2120
from build.utils import set_precision
2221
from cli import add_arguments_for_verb, arg_init
23-
from generate import encode_tokens, model_forward
2422
from utils.measure_time import measure_time
2523

2624
torch._dynamo.config.automatic_dynamic_shapes = True
@@ -85,11 +83,17 @@ def __init__(
8583
self,
8684
model: Transformer,
8785
tokenizer,
86+
model_forward: Optional[Callable] = None,
8887
max_seq_length: Optional[int] = None,
8988
device="cpu",
9089
):
9190
super().__init__(device=device)
9291
self._model = model
92+
self._model_forward = (
93+
model_forward
94+
if model_forward is not None
95+
else lambda x, input_pos: model(x, input_pos)
96+
)
9397
self._tokenizer = tokenizer
9498
self._device = torch.device(device)
9599
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
@@ -116,11 +120,8 @@ def device(self):
116120
return self._device
117121

118122
def tok_encode(self, string: str, **kwargs):
119-
encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device)
120-
# encoded is a pytorch tensor, but some internal logic in the
121-
# eval harness expects it to be a list instead
122-
# TODO: verify this for multi-batch as well
123-
encoded = encoded.tolist()
123+
bos_id = self._tokenizer.bos_id()
124+
encoded = [bos_id] + self._tokenizer.encode(string)
124125
return encoded
125126

126127
def tok_decode(self, tokens):
@@ -142,7 +143,7 @@ def _model_call(self, inps):
142143
)
143144
x = seq.index_select(0, input_pos).view(1, -1)
144145
with measure_time(message=None) as measure:
145-
logits = model_forward(self._model, x, input_pos)
146+
logits = self._model_forward(x, input_pos)
146147
self.times.append(measure.get_time())
147148
return logits
148149

@@ -153,6 +154,7 @@ def _model_generate(self, context, max_length, eos_token_id):
153154
@torch.no_grad()
154155
def eval(
155156
model: Transformer,
157+
model_forward: Callable,
156158
tokenizer,
157159
tasks: Optional[list] = None,
158160
limit: Optional[int] = None,
@@ -176,7 +178,11 @@ def eval(
176178
tasks = ["wikitext"]
177179

178180
model_eval_wrapper = GPTFastEvalWrapper(
179-
model, tokenizer, max_seq_length, device=device
181+
model,
182+
tokenizer,
183+
model_forward=model_forward,
184+
max_seq_length=max_seq_length,
185+
device=device,
180186
)
181187

182188
try:
@@ -231,11 +237,12 @@ def main(args) -> None:
231237
)
232238
tokenizer_args.validate_model(model)
233239

240+
model_forward = lambda x, input_pos: model(x, input_pos) # noqa
241+
234242
if compile:
235243
assert not (
236244
builder_args.dso_path or builder_args.pte_path
237245
), "cannot compile exported model"
238-
global model_forward
239246
model_forward = torch.compile(
240247
model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
241248
)
@@ -244,6 +251,7 @@ def main(args) -> None:
244251
with measure_time("Time to run eval: {time:.02f}s."):
245252
result = eval(
246253
model.to(device),
254+
model_forward,
247255
tokenizer,
248256
tasks,
249257
limit,

0 commit comments

Comments (0)