@@ -22,6 +22,9 @@
 from . import llama_cpp
 from .llama_types import *
 
+import numpy as np
+import numpy.typing as npt
+
 
 
 class LlamaCache:
@@ -76,11 +79,15 @@ def __init__(
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
 
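Note: the new LlamaState fields hold plain numpy snapshots of the evaluated tokens and their logits. np.intc and np.single are picked to match the C-level types on the llama.cpp side (llama_token is a C int, logit entries are C floats). A minimal sanity check of that correspondence, for illustration only:

import ctypes
import numpy as np

# np.intc <-> C int (llama_token), np.single <-> C float (logit entries)
assert np.dtype(np.intc).itemsize == ctypes.sizeof(ctypes.c_int)
assert np.dtype(np.single).itemsize == ctypes.sizeof(ctypes.c_float)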
@@ -210,27 +217,27 @@ def __init__(
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [],
+            dtype=np.dtype(
+                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+            ),
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
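The structured dtype above mirrors llama.cpp's llama_token_data struct (id, logit, p), and align=True asks numpy for C-compiler field alignment, so the buffer can be handed to the C API as-is via ctypes.data_as with no per-token copying. A self-contained sketch of the trick, using a stand-in struct rather than the real llama_cpp binding:

import ctypes
import numpy as np

# Stand-in for llama.cpp's llama_token_data struct: { int id; float logit; float p; }.
# align=True makes numpy lay the fields out the way a C compiler would.
token_data_dtype = np.dtype(
    [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
)

class TokenData(ctypes.Structure):
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]

# The numpy layout and the ctypes struct agree byte-for-byte.
assert token_data_dtype.itemsize == ctypes.sizeof(TokenData)

arr = np.zeros(4, dtype=token_data_dtype)
arr["logit"] = [0.1, 0.2, 0.3, 0.4]
# data_as() exposes the existing buffer as a C pointer; nothing is copied.
ptr = arr.ctypes.data_as(ctypes.POINTER(TokenData))
assert abs(ptr[2].logit - 0.3) < 1e-6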
@@ -298,6 +305,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -309,7 +318,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -322,13 +331,19 @@ def eval(self, tokens: Sequence[int]):
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )
 
     def _sample(
         self,
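eval() now mirrors every batch into the two numpy buffers: _input_ids grows along axis 0 one token at a time, and _scores grows one row of n_vocab logits per stored position. A toy version of that append pattern (n_vocab shrunk to 4, values made up):

import numpy as np

n_vocab = 4
input_ids = np.array([], dtype=np.intc)
scores = np.ndarray((0, n_vocab), dtype=np.single)

batch = [11, 12, 13]
batch_logits = [[0.0] * n_vocab for _ in batch]  # one row per token (logits_all case)

input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)), axis=0)
scores = np.concatenate((scores, np.array(batch_logits, dtype=np.single)), axis=0)

assert input_ids.shape == (3,)
assert scores.shape == (3, n_vocab)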
@@ -349,6 +364,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -357,18 +373,23 @@ def _sample(
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]
 
         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
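The per-token Python loop over the candidates array is replaced by three whole-array assignments into the structured buffer, which is the main speedup in this hunk. The same pattern in isolation (n_vocab shrunk to 5, logits made up):

import numpy as np

n_vocab = 5
logits = np.array([0.1, -0.2, 0.7, 0.0, 0.3], dtype=np.single)

candidates_data = np.zeros(
    n_vocab,
    dtype=np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    ),
)
# Three vectorized stores replace a Python-level loop over every vocab entry.
candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"] = logits
candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)

assert candidates_data["id"][2] == 2
assert candidates_data["logit"][2] == logits[2]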
@@ -486,8 +507,8 @@ def sample(
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
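The last_n_tokens window logic is unchanged; only the source sequence moved from the deque to _input_ids: left-pad with token 0 when fewer tokens have been evaluated than the window holds, then take the tail. In isolation (window of 4, token values made up):

import numpy as np

last_n = 4
input_ids = np.array([7, 8], dtype=np.intc)
# Pad on the left with token 0, then append the most recent tokens.
window = [0] * max(0, last_n - len(input_ids)) + input_ids[-last_n:].tolist()
assert window == [0, 0, 7, 8]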
@@ -545,9 +566,9 @@ def generate(
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -557,6 +578,8 @@ def generate(
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                     try:
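The prefix-match bookkeeping now also truncates the numpy buffers so they stay in lockstep with the deques. The matching itself is a plain longest-common-prefix scan; note it runs over tokens[:-1] so at least one token is always left to feed to eval(). The scan, extracted for clarity:

def longest_common_prefix(evaluated, tokens):
    # Count how many leading tokens of the new prompt agree with what has
    # already been evaluated, so only the tail needs re-evaluation.
    n = 0
    for a, b in zip(evaluated, tokens):
        if a != b:
            break
        n += 1
    return n

# Four tokens already evaluated; the new prompt shares the first two.
assert longest_common_prefix([1, 2, 3, 4], [1, 2, 9]) == 2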
@@ -583,7 +606,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
            ):
                 return
             tokens_or_none = yield token
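After this change a stopping_criteria callable receives plain Python lists (token ids and the final row of scores) rather than a deque and a ctypes view. A hypothetical callable in that shape (the name stop_on_token is made up for illustration):

def stop_on_token(stop_id):
    # Returns a stopping_criteria-style callable: (input_ids, logits) -> bool.
    def criteria(input_ids, logits):
        return len(input_ids) > 0 and input_ids[-1] == stop_id
    return criteria

assert stop_on_token(2)([5, 2], [0.0, 0.0, 0.0]) is True
assert stop_on_token(2)([5, 3], [0.0, 0.0, 0.0]) is False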
@@ -718,10 +741,10 @@ def _create_completion(
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.eval_tokens, prompt_tokens
+                    cache_item.input_ids.tolist(), prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self.eval_tokens, prompt_tokens
+                    self._input_ids.tolist(), prompt_tokens
                 )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
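The .tolist() calls here are not just cosmetic: they turn np.intc scalars back into plain Python ints, so Llama.longest_token_prefix compares the same types it did when the storage was a Deque[int]. For instance:

import numpy as np

ids = np.array([1, 2, 3], dtype=np.intc)
# Elements come back as plain ints, exactly what the prefix helper saw before.
assert all(isinstance(t, int) for t in ids.tolist())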
@@ -810,7 +833,7 @@ def _create_completion(
                             self.detokenize(completion_tokens[:returned_tokens])
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
-                        logits = self.eval_logits[token_offset - 1]
+                        logits = self._scores[token_offset - 1, :].tolist()
                         current_logprobs = Llama.logits_to_logprobs(logits)
                         sorted_logprobs = list(
                             sorted(
@@ -859,7 +882,7 @@ def _create_completion(
                break
 
        if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
        ):
            text = self.detokenize(completion_tokens)
            finish_reason = "stop"
@@ -889,7 +912,7 @@ def _create_completion(
                    self.detokenize(completion_tokens[:returned_tokens])
                )
                token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                current_logprobs = Llama.logits_to_logprobs(logits)
                sorted_logprobs = list(
                    sorted(
@@ -991,8 +1014,7 @@ def _create_completion(
                for token in all_tokens
            ]
            all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
            ][token_offset:]
            for token, token_str, logprobs_token in zip(
                all_tokens, all_token_strs, all_logprobs
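Each row of _scores is one position's raw logits over the vocabulary; Llama.logits_to_logprobs (not shown in this diff) normalizes such a row into log-probabilities. In the standard log-softmax form, which is presumably what it computes, that is:

import math

# Standard log-softmax: shift by the max for numerical stability, normalize, take logs.
def logits_to_logprobs(logits):
    max_logit = max(logits)
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

logprobs = logits_to_logprobs([1.0, 2.0, 3.0])
assert abs(sum(math.exp(x) for x in logprobs) - 1.0) < 1e-9  # probabilities sum to 1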
@@ -1376,6 +1398,8 @@ def save_state(self) -> LlamaState:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1384,6 +1408,8 @@ def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
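save_state() and load_state() copy the arrays rather than alias them; without .copy() the snapshot would share storage with the live buffers and silently change as generation continues. The distinction in a few lines:

import numpy as np

# A snapshot taken without .copy() would alias the live buffer.
live = np.array([1, 2, 3], dtype=np.intc)
saved = live.copy()
live[0] = 99
assert saved[0] == 1  # the copied snapshot is unaffected by later mutation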