4
4
import time
5
5
import math
6
6
import multiprocessing
7
- from typing import List , Optional , Union , Generator , Sequence , Iterator , Deque , Tuple , Callable
7
+ from typing import (
8
+ List ,
9
+ Optional ,
10
+ Union ,
11
+ Generator ,
12
+ Sequence ,
13
+ Iterator ,
14
+ Deque ,
15
+ Tuple ,
16
+ Callable ,
17
+ )
8
18
from collections import deque , OrderedDict
9
19
10
20
from . import llama_cpp
@@ -72,6 +82,24 @@ def __init__(
72
82
self .llama_state_size = llama_state_size
73
83
74
84
85
# Signature of a logits processor: (token ids so far, raw scores) -> adjusted scores.
LogitsProcessor = Callable[[List[int], List[float]], List[float]]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors, applied one after another in list order."""

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        """Feed *scores* through every processor in sequence and return the result.

        Each processor receives the token ids generated so far plus the scores
        produced by the previous processor; an empty list is the identity.
        """
        current = scores
        for fn in self:
            current = fn(input_ids, current)
        return current
93
+
94
+
95
# Signature of a stopping criterion: (token ids so far, last logits) -> stop?
StoppingCriteria = Callable[[List[int], List[float]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria; generation stops when any one of them fires."""

    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
        """Return True if at least one criterion signals that generation should stop.

        Every criterion is evaluated (no short-circuit), matching the original
        list-then-any behavior; an empty list never stops.
        """
        verdicts = [criterion(input_ids, logits) for criterion in self]
        return any(verdicts)
101
+
102
+
75
103
class Llama :
76
104
"""High-level Python wrapper for a llama.cpp model."""
77
105
@@ -316,12 +344,10 @@ def _sample(
316
344
mirostat_tau : llama_cpp .c_float ,
317
345
mirostat_eta : llama_cpp .c_float ,
318
346
penalize_nl : bool = True ,
319
- logits_processors : List [ Callable [[ List [ int ], List [ float ]], List [ float ]]] = None
347
+ logits_processor : Optional [ LogitsProcessorList ] = None ,
320
348
):
321
349
assert self .ctx is not None
322
350
assert len (self .eval_logits ) > 0
323
- if logits_processors is None :
324
- logits_processors = []
325
351
326
352
n_vocab = self .n_vocab ()
327
353
n_ctx = self .n_ctx ()
@@ -332,10 +358,10 @@ def _sample(
332
358
else last_n_tokens_size
333
359
)
334
360
logits = self .eval_logits [- 1 ]
335
- for processor in logits_processors :
336
- logits = processor (list (self .eval_tokens ), logits )
337
361
338
- self .eval_logits [- 1 ] = logits
362
+ if logits_processor is not None :
363
+ logits = logits_processor (list (self .eval_tokens ), logits )
364
+
339
365
nl_logit = logits [self ._token_nl ]
340
366
candidates = self ._candidates
341
367
for i , logit in enumerate (logits ):
@@ -444,8 +470,7 @@ def sample(
444
470
mirostat_eta : float = 0.1 ,
445
471
mirostat_tau : float = 5.0 ,
446
472
penalize_nl : bool = True ,
447
- logits_processors : List [Callable [[List [int ], List [float ]], List [float ]]] = None
448
-
473
+ logits_processor : Optional [LogitsProcessorList ] = None ,
449
474
):
450
475
"""Sample a token from the model.
451
476
@@ -478,8 +503,7 @@ def sample(
478
503
mirostat_tau = llama_cpp .c_float (mirostat_tau ),
479
504
mirostat_eta = llama_cpp .c_float (mirostat_eta ),
480
505
penalize_nl = penalize_nl ,
481
- logits_processors = logits_processors
482
-
506
+ logits_processor = logits_processor ,
483
507
)
484
508
485
509
def generate (
@@ -496,7 +520,8 @@ def generate(
496
520
mirostat_mode : int = 0 ,
497
521
mirostat_tau : float = 5.0 ,
498
522
mirostat_eta : float = 0.1 ,
499
- logits_processors : List [Callable [[List [int ], List [float ]], List [float ]]] = None
523
+ logits_processor : Optional [LogitsProcessorList ] = None ,
524
+ stopping_criteria : Optional [StoppingCriteriaList ] = None ,
500
525
) -> Generator [int , Optional [Sequence [int ]], None ]:
501
526
"""Create a generator of tokens from a prompt.
502
527
@@ -554,8 +579,12 @@ def generate(
554
579
mirostat_mode = mirostat_mode ,
555
580
mirostat_tau = mirostat_tau ,
556
581
mirostat_eta = mirostat_eta ,
557
- logits_processors = logits_processors
582
+ logits_processor = logits_processor ,
558
583
)
584
+ if stopping_criteria is not None and stopping_criteria (
585
+ list (self .eval_tokens ), self .eval_logits [- 1 ]
586
+ ):
587
+ return
559
588
tokens_or_none = yield token
560
589
tokens = [token ]
561
590
if tokens_or_none is not None :
@@ -651,14 +680,9 @@ def _create_completion(
651
680
mirostat_tau : float = 5.0 ,
652
681
mirostat_eta : float = 0.1 ,
653
682
model : Optional [str ] = None ,
654
- logits_processors : List [Callable [[List [int ], List [float ]], List [float ]]] = None ,
655
- stopping_criterias : List [Callable [[List [int ], List [float ]], bool ]] = None ,
656
683
) -> Union [Iterator [Completion ], Iterator [CompletionChunk ]]:
657
684
assert self .ctx is not None
658
685
659
- if stopping_criterias is None :
660
- stopping_criterias = []
661
-
662
686
completion_id : str = f"cmpl-{ str (uuid .uuid4 ())} "
663
687
created : int = int (time .time ())
664
688
completion_tokens : List [int ] = []
@@ -720,22 +744,13 @@ def _create_completion(
720
744
frequency_penalty = frequency_penalty ,
721
745
presence_penalty = presence_penalty ,
722
746
repeat_penalty = repeat_penalty ,
723
- logits_processors = logits_processors
724
747
):
725
748
if token == self ._token_eos :
726
749
text = self .detokenize (completion_tokens )
727
750
finish_reason = "stop"
728
751
break
729
752
730
753
completion_tokens .append (token )
731
- for stopping_crit in stopping_criterias :
732
- if stopping_crit (completion_tokens , None ):
733
- text = self .detokenize (completion_tokens )
734
- finish_reason = "stop"
735
- break
736
-
737
- if finish_reason == "stop" :
738
- break
739
754
740
755
all_text = self .detokenize (completion_tokens )
741
756
@@ -1035,8 +1050,6 @@ def create_completion(
1035
1050
mirostat_tau : float = 5.0 ,
1036
1051
mirostat_eta : float = 0.1 ,
1037
1052
model : Optional [str ] = None ,
1038
- logits_processors : List [Callable [[List [int ], List [float ]], List [float ]]] = None ,
1039
- stopping_criterias : List [Callable [[List [int ], List [float ]], bool ]] = None
1040
1053
) -> Union [Completion , Iterator [CompletionChunk ]]:
1041
1054
"""Generate text from a prompt.
1042
1055
@@ -1079,9 +1092,6 @@ def create_completion(
1079
1092
mirostat_tau = mirostat_tau ,
1080
1093
mirostat_eta = mirostat_eta ,
1081
1094
model = model ,
1082
- logits_processors = logits_processors ,
1083
- stopping_criterias = stopping_criterias
1084
-
1085
1095
)
1086
1096
if stream :
1087
1097
chunks : Iterator [CompletionChunk ] = completion_or_chunks
0 commit comments