
Commit 579ccce

Jack-Khuu authored and facebook-github-bot committed
Format Eval output and enabled cuda support (#2569)
Summary:
Pull Request resolved: #2569

When using eval_llama_lib, evaluation is much faster with CUDA enabled when it is available; this diff enables that. In addition, it wraps the output of eval in a more digestible format.

Reviewed By: jerryzh168

Differential Revision: D55208754

fbshipit-source-id: 8744d58064b6bcab5567a62bb2bf99fe69507aa1
1 parent 725c590 commit 579ccce

File tree

1 file changed: +11 −3 lines changed


examples/models/llama2/eval_llama_lib.py

Lines changed: 11 additions & 3 deletions
@@ -38,7 +38,9 @@ def __init__(
         super().__init__()
         self._model = model
         self._tokenizer = tokenizer
-        self._device = torch.device("cpu")
+        self._device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length

     @property
@@ -153,12 +155,18 @@ def eval_llama(
     tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))

     # Evaluate the model
+    model = (
+        manager.model.eval().to(device="cuda")
+        if torch.cuda.is_available()
+        else manager.model.to(device="cpu")
+    )
     eval_results = eval(
-        manager.model.to(device="cpu"),
+        model,
         tokenizer,
         args.tasks,
         args.limit,
         args.max_seq_length,
     )

-    print("Results: ", eval_results)
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
