
Commit d405f0a

HDCharles authored and malfet committed
adding run time info to eval and cleaning up output (#422)
* adding run time info to eval and cleaning up output

Summary: output now includes info on the model run time distribution and a cleaned-up result output.

Test Plan:
python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda

Time to run eval: 53.31s.
Time in model.forward: 20.29s, over 186 model evaluations
forward run time stats - Median: 0.10s Min: 0.04s Max: 2.18s
For model checkpoints/meta-llama/Llama-2-7b-hf/model.pth
wikitext:
 word_perplexity,none: 9.1649
 byte_perplexity,none: 1.5133
 bits_per_byte,none: 0.5977
 alias: wikitext

* Adding evaluation.md content

Summary: see added content

Test Plan: n/a

* docs update

Summary: removing install instructions
1 parent dad64b4 commit d405f0a

File tree

2 files changed: +40 −7 lines changed


docs/evaluation.md

Lines changed: 27 additions & 4 deletions
@@ -1,7 +1,30 @@
 
-# Model evaluation
+Evaluation Features
+===================
 
-TODO(jerry):
-Add documentation about `torchchat eval` explaining the process and options.
+Torchchat provides evaluation functionality for your language model on a variety of tasks using the [lm-evaluation-harness](https://github.com/facebookresearch/lm_eval) library.
 
-[#339](https://github.com/pytorch/torchchat/issues/339)
+Usage
+-----
+
+The evaluation mode of the `torchchat.py` script can be used to evaluate your language model on various tasks available in the `lm_eval` library, such as "wikitext". You can specify the task(s) to evaluate with the `--tasks` option, and limit the number of evaluated examples with the `--limit` option. If no task is specified, evaluation defaults to "wikitext".
+
+**Examples**
+
+Running wikitext for 10 iterations:
+```
+python3 torchchat.py eval stories15M --tasks wikitext --limit 10
+```
+
+Running an exported model:
+```
+# python3 torchchat.py export stories15M --output-pte-path stories15M.pte
+python3 torchchat.py eval --pte-path stories15M.pte
+```
+
+Running multiple tasks and calling eval.py directly:
+```
+python3 eval.py --pte-path stories15M.pte --tasks wikitext hellaswag
+```
+
+For more information and a list of tasks/metrics see [lm-evaluation-harness](https://github.com/facebookresearch/lm_eval).
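
The per-forward timings behind the new summary lines are returned by `eval()` under `result["times"]` (see the eval.py diff below). A minimal sketch of how those summary statistics are derived from that list; the `summarize_times` name is hypothetical, while the tensor calls mirror `main()` in eval.py:

```python
import torch

def summarize_times(result):
    # result["times"] holds one wall-clock duration (in seconds) per
    # model.forward call, collected during evaluation.
    times = torch.tensor(result["times"])
    print(f"Time in model.forward: {times.sum():.02f}s, over {times.numel()} model evaluations")
    print(
        f"forward run time stats - Median: {times.median():.02f}s "
        f"Min: {times.min():.02f}s Max: {times.max():.02f}s"
    )
```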

eval.py

Lines changed: 13 additions & 3 deletions
@@ -28,7 +28,7 @@
 torch._inductor.config.epilogue_fusion = False
 torch._inductor.config.triton.cudagraphs = True
 torch._dynamo.config.cache_size_limit = 100000
-
+import time
 
 try:
     import lm_eval
@@ -108,6 +108,7 @@ def __init__(
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self.times = []
 
     @property
     def eot_token_id(self):
@@ -155,7 +156,9 @@ def _model_call(self, inps):
             )
         )
         x = seq.index_select(0, input_pos).view(1, -1)
+        start = time.time()
         logits = model_forward(self._model, x, input_pos)
+        self.times.append(time.time() - start)
         return logits
 
     def _model_generate(self, context, max_length, eos_token_id):
@@ -206,6 +209,7 @@ def eval(
         task_dict,
         limit=limit,
     )
+    eval_results["times"] = model_eval_wrapper.times
     return eval_results
 
 
@@ -261,7 +265,10 @@ def main(args) -> None:
         max_seq_length,
         device=builder_args.device,
     )
-    print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
+    print(f"Time to run eval: {time.time() - t1:.02f}s.")
+    times = torch.tensor(result["times"])
+    print(f"Time in model.forward: {times.sum():.02f}s, over {times.numel()} model evaluations")
+    print(f"forward run time stats - Median: {times.median():.02f}s Min: {times.min():.02f}s Max: {times.max():.02f}s")
     if builder_args.dso_path:
         print(f"For model {builder_args.dso_path}")
     elif builder_args.pte_path:
@@ -274,7 +281,10 @@ def main(args) -> None:
         raise RuntimeError("Well That's Fine. How did we get here")
 
     for task, res in result["results"].items():
-        print(f"{task}: {res}")
+        print(f"{task}:")
+        for metric, val in res.items():
+            if val != "N/A":
+                print(f" {metric}: {val if isinstance(val, str) else f'{val:0.4f}'}")
 
 
 if __name__ == "__main__":
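
The timing added to `_model_call` above is a plain wrap-and-append pattern around `model_forward`. A self-contained illustration of the same idea (not part of this commit; `record_call_times` is a hypothetical helper name):

```python
import time

def record_call_times(fn, times):
    # Return a wrapper around fn that appends each call's wall-clock
    # duration (in seconds) to the `times` list.
    def wrapped(*args, **kwargs):
        start = time.time()
        result = fn(*args, **kwargs)
        times.append(time.time() - start)
        return result
    return wrapped

# Usage sketch: collect per-call durations, then summarize.
times = []
square = record_call_times(lambda x: x * x, times)
square(3)
print(f"{len(times)} call(s), total {sum(times):.06f}s")
```

With `self.times` populated this way, `main` can build `torch.tensor(result["times"])` and print the median/min/max exactly as the diff above does.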
