@@ -177,19 +177,19 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-        n_vocab = self.n_vocab()
-        n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * n_vocab)(
+        self._n_vocab = self.n_vocab()
+        self._n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * self._n_vocab)(
             *[
                 llama_cpp.llama_token_data(
                     id=llama_cpp.llama_token(i),
                     logit=llama_cpp.c_float(0.0),
                     p=llama_cpp.c_float(0.0),
                 )
-                for i in range(n_vocab)
+                for i in range(self._n_vocab)
             ]
         )
-        size = llama_cpp.c_size_t(n_vocab)
+        size = llama_cpp.c_size_t(self._n_vocab)
         sorted = False
         candidates = llama_cpp.llama_token_data_array(
             data=data,
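Note: the hunk above caches the vocabulary size and context length once in __init__, since both are fixed for the lifetime of the llama context; hot paths then read a Python attribute instead of crossing the ctypes boundary on every call. A minimal sketch of the pattern, assuming only the low-level llama_cpp bindings (the CachedDims class name is illustrative, not from the diff):

import llama_cpp  # low-level ctypes bindings used throughout llama.py

class CachedDims:
    """Sketch: read the fixed context dimensions once and reuse them."""

    def __init__(self, ctx):
        self.ctx = ctx
        # llama_n_vocab / llama_n_ctx are constant for a given context,
        # so one FFI call each at construction time is enough.
        self._n_vocab = int(llama_cpp.llama_n_vocab(ctx))
        self._n_ctx = int(llama_cpp.llama_n_ctx(ctx))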
@@ -213,18 +213,18 @@ def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
             A list of tokens.
         """
         assert self.ctx is not None
-        n_ctx = llama_cpp.llama_n_ctx(self.ctx)
-        tokens = (llama_cpp.llama_token * int(n_ctx))()
+        n_ctx = self._n_ctx
+        tokens = (llama_cpp.llama_token * n_ctx)()
         n_tokens = llama_cpp.llama_tokenize(
             self.ctx,
             text,
             tokens,
             llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
-        if int(n_tokens) < 0:
+        if n_tokens < 0:
             n_tokens = abs(n_tokens)
-            tokens = (llama_cpp.llama_token * int(n_tokens))()
+            tokens = (llama_cpp.llama_token * n_tokens)()
             n_tokens = llama_cpp.llama_tokenize(
                 self.ctx,
                 text,
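The tokenize hunk keeps the usual C-API sizing idiom: try once with a context-sized buffer, and if llama_tokenize signals "buffer too small" with a negative count, reallocate at exactly abs(n_tokens) and call again. A standalone sketch of that retry shape, where fake_tokenize is a hypothetical stand-in for the C call:

import ctypes
from typing import Callable, List

def tokenize_with_retry(
    fake_tokenize: Callable, text: bytes, n_ctx: int
) -> List[int]:
    # First attempt with a context-sized buffer.
    tokens = (ctypes.c_int * n_ctx)()
    n_tokens = fake_tokenize(text, tokens, n_ctx)
    if n_tokens < 0:
        # A negative return means the buffer was too small; its absolute
        # value is the required length, so retry with an exact-size buffer.
        n_needed = abs(n_tokens)
        tokens = (ctypes.c_int * n_needed)()
        n_tokens = fake_tokenize(text, tokens, n_needed)
        assert n_tokens >= 0, "tokenization failed on retry"
    return list(tokens[:n_tokens])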
@@ -275,7 +275,7 @@ def eval(self, tokens: Sequence[int]):
             tokens: The list of tokens to evaluate.
         """
         assert self.ctx is not None
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = min(n_ctx - len(batch), len(self.eval_tokens))
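The batching arithmetic around the cached n_ctx is plain Python: tokens are fed n_batch at a time, and n_past is clamped so history plus the current batch never exceeds the context window. A self-contained sketch of that loop (names are illustrative):

from typing import Iterator, List, Sequence, Tuple

def iter_eval_batches(
    tokens: Sequence[int], n_batch: int, n_ctx: int, n_history: int = 0
) -> Iterator[Tuple[List[int], int]]:
    # Yield (batch, n_past) pairs the way the eval loop walks them.
    for i in range(0, len(tokens), n_batch):
        batch = list(tokens[i : min(len(tokens), i + n_batch)])
        # Clamp so past tokens plus this batch fit in the context window.
        n_past = min(n_ctx - len(batch), n_history)
        yield batch, n_past
        n_history += len(batch)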
@@ -287,18 +287,16 @@ def eval(self, tokens: Sequence[int]):
                 n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
-            if int(return_code) != 0:
+            if return_code != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
-            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-            cols = int(n_vocab)
+            n_vocab = self._n_vocab
+            cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits: List[List[float]] = [
-                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
-            ]
+            logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
 
     def _sample(
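The logits rewrite above leans on a ctypes detail worth calling out: llama_get_logits returns a POINTER(c_float), and slicing such a pointer (logits_view[a:b]) copies the whole range into a Python list in a single C-level pass, rather than rows * cols separate __getitem__ calls as in the removed nested comprehension. A self-contained demonstration of the slicing behavior:

import ctypes

# Build a flat rows x cols float buffer, shaped like what
# llama_get_logits exposes.
rows, cols = 2, 4
flat = (ctypes.c_float * (rows * cols))(*range(rows * cols))
view = ctypes.cast(flat, ctypes.POINTER(ctypes.c_float))

# Slicing a ctypes pointer materializes the range as a Python list in
# one call, much cheaper than indexing element by element.
logits = [view[i * cols : (i + 1) * cols] for i in range(rows)]
assert logits == [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]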
@@ -319,8 +317,8 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = self.n_vocab()
-        n_ctx = self.n_ctx()
+        n_vocab = self._n_vocab
+        n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
@@ -654,9 +652,9 @@ def _create_completion(
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
-        if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
+        if len(prompt_tokens) + max_tokens > self._n_ctx:
             raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens exceed context window of {self._n_ctx}"
             )
 
         if stop != []:
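With _n_ctx cached, the pre-flight check in _create_completion is a cheap integer comparison. A hypothetical caller-side sketch of reacting to the ValueError it raises (llm, prompt, and max_tokens are assumed to exist; none of this is in the diff):

# Assumes llm is a constructed Llama instance and prompt is a str.
try:
    output = llm.create_completion(prompt, max_tokens=max_tokens)
except ValueError:
    # Retry with however much room the context window actually leaves.
    budget = llm._n_ctx - len(llm.tokenize(prompt.encode("utf-8")))
    output = llm.create_completion(prompt, max_tokens=max(1, budget))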