Skip to content

Commit ae8a253

Browse files
mikekgfbmalfet
authored andcommitted
initial enablement for device support (#314)
1 parent a9f134c commit ae8a253

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

eval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,12 @@ def __init__(
107107
model: Transformer,
108108
tokenizer,
109109
max_seq_length: Optional[int] = None,
110+
device = "cpu"
110111
):
111112
super().__init__()
112113
self._model = model
113114
self._tokenizer = tokenizer
114-
self._device = torch.device("cuda")
115+
self._device = torch.device(device)
115116
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
116117

117118
@property
@@ -174,6 +175,7 @@ def eval(
174175
tasks: Optional[list] = None,
175176
limit: Optional[int] = None,
176177
max_seq_length: Optional[int] = None,
178+
device: str = "cpu"
177179
) -> dict:
178180
"""
179181
Evaluates a language model on a specified task using the lm-evaluation-harness library.
@@ -195,6 +197,7 @@ def eval(
195197
model,
196198
tokenizer,
197199
max_seq_length,
200+
device=device
198201
)
199202

200203
try:
@@ -267,6 +270,7 @@ def main(args) -> None:
267270
tasks,
268271
limit,
269272
max_seq_length,
273+
device=builder_args.device,
270274
)
271275
print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
272276
if builder_args.dso_path:

0 commit comments

Comments
 (0)