@@ -51,10 +51,11 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
+    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
         self.params = model_args
         self.tokenizer = get_tokenizer(tokenizer_path)
         assert model_args.vocab_size == self.tokenizer.n_words
+        self.device = device
 
     @abstractmethod
     def forward(
@@ -73,9 +74,9 @@ def generate(  # noqa: C901
     ) -> List[int]:
         # prefill
         logits = self.forward(
-            tokens=torch.tensor([prompt_tokens], dtype=torch.long),
+            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
+                torch.tensor([0], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -87,14 +88,21 @@ def generate(  # noqa: C901
         while len(tokens) < self.params.max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
-                    tokens=torch.tensor([[current_token]], dtype=torch.long),
-                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
+                    tokens=torch.tensor(
+                        [[current_token]], dtype=torch.long, device=self.device
+                    ),
+                    input_pos=torch.tensor(
+                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                    ),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
+                )
             current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
-                hasattr(self, "stop_tokens") and current_token in self.stop_tokens
+                hasattr(self.tokenizer, "stop_tokens")
+                and current_token in self.tokenizer.stop_tokens
             ):
                 break
             tokens.append(current_token)
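For context, a minimal sketch of how a concrete runner would consume the new device argument; the EagerLlamaRunner name and load_checkpoint helper below are illustrative assumptions, not part of this diff:

    # Illustrative subclass; EagerLlamaRunner and load_checkpoint() are hypothetical.
    from typing import Optional

    import torch

    class EagerLlamaRunner(LlamaRunner):
        def __init__(
            self,
            ckpt_path: str,
            tokenizer_path: str,
            model_args: ModelArgs,
            device: str = "cpu",
        ):
            # Base class stores self.device, which generate() uses for its tensors.
            super().__init__(tokenizer_path, model_args, device)
            # Keep the weights on the same device generate() allocates inputs on.
            self.model = load_checkpoint(ckpt_path, model_args).to(device)

        def forward(
            self, tokens: torch.Tensor, input_pos: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            # tokens/input_pos already live on self.device; no per-step .to() copies.
            return self.model(tokens, input_pos)

    # Usage: runner = EagerLlamaRunner("ckpt.pt", "tokenizer.model", args, device="cuda")

Threading device through the base class means every tensor generate() builds (prefill prompt, per-step token, input_pos) lands on the same device as the model, avoiding CPU/GPU mismatch errors during decoding.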