File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -107,11 +107,12 @@ def __init__(
107
107
model : Transformer ,
108
108
tokenizer ,
109
109
max_seq_length : Optional [int ] = None ,
110
+ device = "cpu"
110
111
):
111
112
super ().__init__ ()
112
113
self ._model = model
113
114
self ._tokenizer = tokenizer
114
- self ._device = torch .device ("cuda" )
115
+ self ._device = torch .device (device )
115
116
self ._max_seq_length = 2048 if max_seq_length is None else max_seq_length
116
117
117
118
@property
@@ -174,6 +175,7 @@ def eval(
174
175
tasks : Optional [list ] = None ,
175
176
limit : Optional [int ] = None ,
176
177
max_seq_length : Optional [int ] = None ,
178
+ device : str = "cpu"
177
179
) -> dict :
178
180
"""
179
181
Evaluates a language model on a specified task using the lm-evaluation-harness library.
@@ -195,6 +197,7 @@ def eval(
195
197
model ,
196
198
tokenizer ,
197
199
max_seq_length ,
200
+ device = device
198
201
)
199
202
200
203
try :
@@ -267,6 +270,7 @@ def main(args) -> None:
267
270
tasks ,
268
271
limit ,
269
272
max_seq_length ,
273
+ device = builder_args .device ,
270
274
)
271
275
print (f"Time to run eval: { time .time () - t1 :.02f} seconds." )
272
276
if builder_args .dso_path :
You can’t perform that action at this time.
0 commit comments