import sys
import uuid
import time
+import math
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator
from collections import deque

from . import llama_cpp
from .llama_types import *


+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
class Llama:
    """High-level Python wrapper for a llama.cpp model."""

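A minimal usage sketch of the new cache hook (not part of the diff; the model path is a placeholder, and the cache is enabled through the set_cache() method added further down):

from llama_cpp.llama import Llama, LlamaCache

llm = Llama(model_path="./models/ggml-model-q4_0.bin")  # placeholder path
llm.set_cache(LlamaCache())  # opt in to continuing from the last completion

first = llm("Q: Name the planets in the solar system. A:", max_tokens=32)

# The second prompt starts with the first prompt plus its generated text,
# so the wrapper can skip re-evaluating that prefix instead of resetting.
second = llm(
    "Q: Name the planets in the solar system. A:"
    + first["choices"][0]["text"]
    + " Q: Which of them is largest? A:",
    max_tokens=32,
)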
@@ -20,7 +30,7 @@ def __init__(
        n_ctx: int = 512,
        n_parts: int = -1,
        seed: int = 1337,
-        f16_kv: bool = False,
+        f16_kv: bool = True,
        logits_all: bool = False,
        vocab_only: bool = False,
        use_mmap: bool = True,
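The only change in this hunk is the f16_kv default flipping to True, so the key/value cache is kept in 16-bit floats unless a caller opts out. A hedged constructor sketch (placeholder model path):

from llama_cpp.llama import Llama

# New default: f16 KV cache.
llm = Llama(model_path="./models/ggml-model-q4_0.bin", n_ctx=512, seed=1337)

# Previous behaviour, now opt-in.
llm_f32 = Llama(model_path="./models/ggml-model-q4_0.bin", f16_kv=False)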
@@ -75,7 +85,19 @@ def __init__(
            maxlen=self.last_n_tokens_size,
        )
        self.tokens_consumed = 0
+        self.tokens: List[llama_cpp.llama_token] = []
        self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###

        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)

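The _completion_bytes list introduced here feeds a simple prefix test later in _create_completion. A standalone sketch of that test with invented byte strings, to show when a new prompt can continue from the previous call:

# Illustrative only: mirrors the prefix comparison used by the HACK above.
completion_bytes = [b"Q: What is 2+2? A:", b" 4."]  # last prompt + generated text
new_prompt = b"Q: What is 2+2? A: 4. Q: And 3+3? A:"

previous = b"".join(completion_bytes)
if previous and new_prompt.startswith(previous):
    remaining = new_prompt[len(previous):]
    print("cache hit, only the new suffix needs evaluating:", remaining)
else:
    print("cache miss, full reset and re-evaluation")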
@@ -130,12 +152,24 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
            output += llama_cpp.llama_token_to_str(self.ctx, token)
        return output

+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
    def reset(self):
        """Reset the model state."""
        self.last_n_tokens_data.extend(
            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
        )
        self.tokens_consumed = 0
+        self.tokens.clear()
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits.clear()

    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
        """Evaluate a list of tokens.
@@ -147,18 +181,32 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
        for i in range(0, len(tokens), self.n_batch):
            batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
            return_code = llama_cpp.llama_eval(
                ctx=self.ctx,
                tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                n_threads=llama_cpp.c_int(self.n_threads),
            )
            if int(return_code) != 0:
                raise RuntimeError(f"llama_eval returned {return_code}")
+            self.tokens.extend(batch)
            self.last_n_tokens_data.extend(batch)
            self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits

    def sample(
        self,
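eval() feeds the prompt to llama_eval in n_batch-sized chunks and tracks n_past so each chunk is appended to the existing context. A small worked example of that bookkeeping (the numbers are arbitrary):

# Illustrative arithmetic only, using the same formula as eval() above.
n_ctx = 512
n_batch = 8
tokens_consumed = 0
prompt_len = 20  # pretend the prompt tokenized to 20 tokens

for start in range(0, prompt_len, n_batch):
    batch_len = min(prompt_len, start + n_batch) - start
    n_past = min(n_ctx - batch_len, tokens_consumed)
    print(f"evaluate {batch_len} tokens at n_past={n_past}")
    tokens_consumed += batch_len
# -> batches of 8, 8, 4 at n_past = 0, 8, 16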
@@ -198,6 +246,7 @@ def generate(
        top_p: float,
        temp: float,
        repeat_penalty: float,
+        reset: bool = True,
    ) -> Generator[
        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
    ]:
@@ -215,12 +264,25 @@ def generate(
            top_p: The top-p sampling parameter.
            temp: The temperature parameter.
            repeat_penalty: The repeat penalty parameter.
+            reset: Whether to reset the model state.

        Yields:
            The generated tokens.
        """
        assert self.ctx is not None
-        self.reset()
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
+        if reset:
+            self.reset()
        while True:
            self.eval(tokens)
            token = self.sample(
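The reset logic added to generate() only skips the reset when the previously evaluated tokens are an exact prefix of the incoming token ids. A self-contained sketch of that comparison with dummy ids:

# Illustrative only: the same prefix test generate() performs above.
previously_evaluated = [1, 15043, 3186]      # token ids already in the context
incoming = [1, 15043, 3186, 29871, 13]       # token ids for the new request

reset = True
if previously_evaluated and incoming[: len(previously_evaluated)] == previously_evaluated:
    reset = False  # keep the context, only the new tail gets evaluated

print("reset" if reset else "generate cache hit")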
@@ -300,19 +362,22 @@ def _create_completion(
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
-    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
        assert self.ctx is not None
-        completion_id = f"cmpl-{str(uuid.uuid4())}"
-        created = int(time.time())
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
        completion_tokens: List[llama_cpp.llama_token] = []
        # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
-        text = b""
-        returned_characters = 0
+        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
+            b" " + prompt.encode("utf-8")
+        )
+        text: bytes = b""
+        returned_characters: int = 0
+        stop = stop if stop is not None else []

        if self.verbose:
            llama_cpp.llama_reset_timings(self.ctx)
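stop may now be passed as None and is normalised to an empty list before use. A brief call-site sketch (placeholder model path); both calls behave identically:

from llama_cpp.llama import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin")  # placeholder path
a = llm("Once upon a time", stop=None, max_tokens=16)
b = llm("Once upon a time", stop=[], max_tokens=16)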
@@ -327,13 +392,34 @@ def _create_completion(
        else:
            stop_sequences = []

-        finish_reason = None
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
+        finish_reason = "length"
        for token in self.generate(
            prompt_tokens,
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
+            reset=reset,
        ):
            if token == llama_cpp.llama_token_eos():
                text = self.detokenize(completion_tokens)
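Asking for logprobs now fails fast unless the model was loaded with logits_all=True, since per-token logits are only retained in that mode. A hedged sketch (placeholder model path):

from llama_cpp.llama import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin", logits_all=True)
out = llm("The capital of France is", max_tokens=8, logprobs=5)
print(out["choices"][0]["logprobs"]["top_logprobs"][0])

# Without logits_all=True the same call now raises:
# ValueError: logprobs is not supported for models created with logits_all=False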
@@ -363,6 +449,9 @@ def _create_completion(
                            break
                text = all_text[: len(all_text) - longest]
                returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                yield {
                    "id": completion_id,
                    "object": "text_completion",
@@ -377,15 +466,16 @@ def _create_completion(
                        }
                    ],
                }
+
            if len(completion_tokens) >= max_tokens:
                text = self.detokenize(completion_tokens)
                finish_reason = "length"
                break

-        if finish_reason is None:
-            finish_reason = "length"
-
        if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
            yield {
                "id": completion_id,
                "object": "text_completion",
@@ -402,16 +492,57 @@ def _create_completion(
            }
            return

-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")

        if echo:
-            text = prompt + text
+            text_str = prompt + text_str

        if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix

+        logprobs_or_none: Optional[CompletionLogprobs] = None
        if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }

        if self.verbose:
            llama_cpp.llama_print_timings(self.ctx)
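The block above builds an OpenAI-style logprobs object: parallel lists of token strings, character offsets and per-token values, plus a top-k mapping per position. A sketch of the resulting shape as seen by a caller (the concrete values are invented):

# Shape of response["choices"][0]["logprobs"] assembled above (values invented).
logprobs_obj = {
    "tokens": [" The", " capital"],
    "text_offset": [0, 4],
    "token_logprobs": [0.91, 0.88],
    "top_logprobs": [
        {" The": 0.91, " A": 0.42},
        {" capital": 0.88, " city": 0.35},
    ],
}
for tok, off, lp in zip(
    logprobs_obj["tokens"], logprobs_obj["text_offset"], logprobs_obj["token_logprobs"]
):
    print(f"{off:>3} {tok!r} {lp:.2f}")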
@@ -423,9 +554,9 @@ def _create_completion(
            "model": self.model_path,
            "choices": [
                {
-                    "text": text,
+                    "text": text_str,
                    "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                    "finish_reason": finish_reason,
                }
            ],
@@ -445,7 +576,7 @@ def create_completion(
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
@@ -500,7 +631,7 @@ def __call__(
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
@@ -602,12 +733,12 @@ def _convert_text_completion_chunks_to_chat(
    def create_chat_completion(
        self,
        messages: List[ChatCompletionMessage],
-        temperature: float = 0.8,
+        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
-        stop: List[str] = [],
-        max_tokens: int = 128,
+        stop: Optional[List[str]] = [],
+        max_tokens: int = 256,
        repeat_penalty: float = 1.1,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a chat completion from a list of messages.
@@ -625,13 +756,13 @@ def create_chat_completion(
        Returns:
            Generated chat completion or a stream of chat completion chunks.
        """
-        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
-        chat_history = "\n".join(
-            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
+        stop = stop if stop is not None else []
+        chat_history = "".join(
+            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
            for message in messages
        )
-        PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
-        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
+        PROMPT = chat_history + "### Assistant:"
+        PROMPT_STOP = ["### Assistant:", "### Human:"]
        completion_or_chunks = self(
            prompt=PROMPT,
            stop=PROMPT_STOP + stop,
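The chat prompt now uses a "### Human:" / "### Assistant:" transcript instead of the old Instructions/Inputs/Response template; any role other than "user" (including "system") is rendered as "Assistant". A worked sketch of the string the wrapper builds:

messages = [
    {"role": "user", "content": " Hello!"},
    {"role": "assistant", "content": " Hi, how can I help?"},
    {"role": "user", "content": " Tell me a joke."},
]
chat_history = "".join(
    f'### {"Human" if m["role"] == "user" else "Assistant"}:{m["content"]}'
    for m in messages
)
PROMPT = chat_history + "### Assistant:"
print(PROMPT)
# ### Human: Hello!### Assistant: Hi, how can I help?### Human: Tell me a joke.### Assistant: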
@@ -668,8 +799,6 @@ def __getstate__(self):
            use_mlock=self.params.use_mlock,
            embedding=self.params.embedding,
            last_n_tokens_size=self.last_n_tokens_size,
-            last_n_tokens_data=self.last_n_tokens_data,
-            tokens_consumed=self.tokens_consumed,
            n_batch=self.n_batch,
            n_threads=self.n_threads,
        )
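With last_n_tokens_data and tokens_consumed dropped from the pickled state, a Llama instance now round-trips through pickle as a freshly constructed model with the same parameters and an empty context. A hedged sketch (placeholder model path; the model is loaded again on unpickling):

import pickle

from llama_cpp.llama import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin")  # placeholder path
blob = pickle.dumps(llm)        # serialises constructor parameters only
restored = pickle.loads(blob)   # re-runs __init__ via __setstate__
assert restored.tokens_consumed == 0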
@@ -691,9 +820,6 @@ def __setstate__(self, state):
            last_n_tokens_size=state["last_n_tokens_size"],
            verbose=state["verbose"],
        )
-        self.last_n_tokens_data = state["last_n_tokens_data"]
-        self.tokens_consumed = state["tokens_consumed"]
-

    @staticmethod
    def token_eos() -> llama_cpp.llama_token:
@@ -704,3 +830,7 @@ def token_eos() -> llama_cpp.llama_token:
    def token_bos() -> llama_cpp.llama_token:
        """Return the beginning-of-sequence token."""
        return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))
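A quick numeric check of the new static helper, evaluating the same expression it returns (values rounded):

import math

# Same expression as Llama.logit_to_logprob above.
print(math.log(1.0 + math.exp(0.0)))   # ~0.6931
print(math.log(1.0 + math.exp(2.0)))   # ~2.1269
print(math.log(1.0 + math.exp(-4.0)))  # ~0.0181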