Skip to content

Commit 1d247e0

Browse files
committed
Add StoppingCriteria and LogitsProcessor to generate to match huggingface API
1 parent c6a9659 commit 1d247e0

File tree

1 file changed

+42
-32
lines changed

1 file changed

+42
-32
lines changed

llama_cpp/llama.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,17 @@
44
import time
55
import math
66
import multiprocessing
7-
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
7+
from typing import (
8+
List,
9+
Optional,
10+
Union,
11+
Generator,
12+
Sequence,
13+
Iterator,
14+
Deque,
15+
Tuple,
16+
Callable,
17+
)
818
from collections import deque, OrderedDict
919

1020
from . import llama_cpp
@@ -72,6 +82,24 @@ def __init__(
7282
self.llama_state_size = llama_state_size
7383

7484

85+
# A logits processor maps (token ids generated so far, raw logits) to
# adjusted logits, huggingface-style.
LogitsProcessor = Callable[[List[int], List[float]], List[float]]


class LogitsProcessorList(List[LogitsProcessor]):
    """Ordered pipeline of logits processors.

    Calling the list applies each processor in order: every processor
    receives the token ids produced so far plus the current scores and
    returns the adjusted scores handed to the next processor.
    """

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        current = scores
        for transform in self:
            current = transform(input_ids, current)
        return current
93+
94+
95+
# A stopping criterion inspects (token ids generated so far, last logits)
# and returns True when generation should stop, huggingface-style.
StoppingCriteria = Callable[[List[int], List[float]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """Collection of stopping criteria; generation stops when any one fires.

    Intended to mirror the huggingface ``StoppingCriteriaList`` API
    (per the commit message — confirm against transformers docs).
    """

    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
        # Generator expression (not a list) lets `any` short-circuit on the
        # first criterion that returns True instead of evaluating all of them.
        return any(stopping_criteria(input_ids, logits) for stopping_criteria in self)
101+
102+
75103
class Llama:
76104
"""High-level Python wrapper for a llama.cpp model."""
77105

@@ -316,12 +344,10 @@ def _sample(
316344
mirostat_tau: llama_cpp.c_float,
317345
mirostat_eta: llama_cpp.c_float,
318346
penalize_nl: bool = True,
319-
logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
347+
logits_processor: Optional[LogitsProcessorList] = None,
320348
):
321349
assert self.ctx is not None
322350
assert len(self.eval_logits) > 0
323-
if logits_processors is None:
324-
logits_processors = []
325351

326352
n_vocab = self.n_vocab()
327353
n_ctx = self.n_ctx()
@@ -332,10 +358,10 @@ def _sample(
332358
else last_n_tokens_size
333359
)
334360
logits = self.eval_logits[-1]
335-
for processor in logits_processors:
336-
logits = processor(list(self.eval_tokens), logits)
337361

338-
self.eval_logits[-1] = logits
362+
if logits_processor is not None:
363+
logits = logits_processor(list(self.eval_tokens), logits)
364+
339365
nl_logit = logits[self._token_nl]
340366
candidates = self._candidates
341367
for i, logit in enumerate(logits):
@@ -444,8 +470,7 @@ def sample(
444470
mirostat_eta: float = 0.1,
445471
mirostat_tau: float = 5.0,
446472
penalize_nl: bool = True,
447-
logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
448-
473+
logits_processor: Optional[LogitsProcessorList] = None,
449474
):
450475
"""Sample a token from the model.
451476
@@ -478,8 +503,7 @@ def sample(
478503
mirostat_tau=llama_cpp.c_float(mirostat_tau),
479504
mirostat_eta=llama_cpp.c_float(mirostat_eta),
480505
penalize_nl=penalize_nl,
481-
logits_processors=logits_processors
482-
506+
logits_processor=logits_processor,
483507
)
484508

485509
def generate(
@@ -496,7 +520,8 @@ def generate(
496520
mirostat_mode: int = 0,
497521
mirostat_tau: float = 5.0,
498522
mirostat_eta: float = 0.1,
499-
logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
523+
logits_processor: Optional[LogitsProcessorList] = None,
524+
stopping_criteria: Optional[StoppingCriteriaList] = None,
500525
) -> Generator[int, Optional[Sequence[int]], None]:
501526
"""Create a generator of tokens from a prompt.
502527
@@ -554,8 +579,12 @@ def generate(
554579
mirostat_mode=mirostat_mode,
555580
mirostat_tau=mirostat_tau,
556581
mirostat_eta=mirostat_eta,
557-
logits_processors=logits_processors
582+
logits_processor=logits_processor,
558583
)
584+
if stopping_criteria is not None and stopping_criteria(
585+
list(self.eval_tokens), self.eval_logits[-1]
586+
):
587+
return
559588
tokens_or_none = yield token
560589
tokens = [token]
561590
if tokens_or_none is not None:
@@ -651,14 +680,9 @@ def _create_completion(
651680
mirostat_tau: float = 5.0,
652681
mirostat_eta: float = 0.1,
653682
model: Optional[str] = None,
654-
logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
655-
stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None,
656683
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
657684
assert self.ctx is not None
658685

659-
if stopping_criterias is None:
660-
stopping_criterias = []
661-
662686
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
663687
created: int = int(time.time())
664688
completion_tokens: List[int] = []
@@ -720,22 +744,13 @@ def _create_completion(
720744
frequency_penalty=frequency_penalty,
721745
presence_penalty=presence_penalty,
722746
repeat_penalty=repeat_penalty,
723-
logits_processors=logits_processors
724747
):
725748
if token == self._token_eos:
726749
text = self.detokenize(completion_tokens)
727750
finish_reason = "stop"
728751
break
729752

730753
completion_tokens.append(token)
731-
for stopping_crit in stopping_criterias:
732-
if stopping_crit(completion_tokens, None):
733-
text = self.detokenize(completion_tokens)
734-
finish_reason = "stop"
735-
break
736-
737-
if finish_reason == "stop":
738-
break
739754

740755
all_text = self.detokenize(completion_tokens)
741756

@@ -1035,8 +1050,6 @@ def create_completion(
10351050
mirostat_tau: float = 5.0,
10361051
mirostat_eta: float = 0.1,
10371052
model: Optional[str] = None,
1038-
logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
1039-
stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None
10401053
) -> Union[Completion, Iterator[CompletionChunk]]:
10411054
"""Generate text from a prompt.
10421055
@@ -1079,9 +1092,6 @@ def create_completion(
10791092
mirostat_tau=mirostat_tau,
10801093
mirostat_eta=mirostat_eta,
10811094
model=model,
1082-
logits_processors=logits_processors,
1083-
stopping_criterias=stopping_criterias
1084-
10851095
)
10861096
if stream:
10871097
chunks: Iterator[CompletionChunk] = completion_or_chunks

0 commit comments

Comments (0)