We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f2e660b · commit 591a549 — Copy full SHA for 591a549
examples/models/llama2/eval_llama_lib.py
@@ -42,12 +42,11 @@ def __init__(
42
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
43
max_seq_length: Optional[int] = None,
44
):
45
- super().__init__()
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+ super().__init__(device=device)
47
self._model = model
48
self._tokenizer = tokenizer
- self._device = (
49
- torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
50
- )
+ self._device = torch.device(device)
51
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
52
53
@property
0 commit comments