Skip to content

Commit da463e6

Browse files
Added type annotations to the logits-processor list and stopping-criteria list parameters
1 parent c05fcdf commit da463e6

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

llama_cpp/llama.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
import math
66
import multiprocessing
7-
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
7+
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
88
from collections import deque, OrderedDict
99

1010
from . import llama_cpp
@@ -316,12 +316,11 @@ def _sample(
316316
mirostat_tau: llama_cpp.c_float,
317317
mirostat_eta: llama_cpp.c_float,
318318
penalize_nl: bool = True,
319-
logits_processors=None
319+
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
320320
):
321321
assert self.ctx is not None
322322
assert len(self.eval_logits) > 0
323-
324-
if logits_processors == None:
323+
if logits_processors is None:
325324
logits_processors = []
326325

327326
n_vocab = self.n_vocab()
@@ -445,7 +444,7 @@ def sample(
445444
mirostat_eta: float = 0.1,
446445
mirostat_tau: float = 5.0,
447446
penalize_nl: bool = True,
448-
logits_processors=None
447+
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
449448

450449
):
451450
"""Sample a token from the model.
@@ -497,7 +496,7 @@ def generate(
497496
mirostat_mode: int = 0,
498497
mirostat_tau: float = 5.0,
499498
mirostat_eta: float = 0.1,
500-
logits_processors=None
499+
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
501500
) -> Generator[int, Optional[Sequence[int]], None]:
502501
"""Create a generator of tokens from a prompt.
503502
@@ -652,12 +651,12 @@ def _create_completion(
652651
mirostat_tau: float = 5.0,
653652
mirostat_eta: float = 0.1,
654653
model: Optional[str] = None,
655-
logits_processors=None,
656-
stopping_criterias=None
654+
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
655+
stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None,
657656
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
658657
assert self.ctx is not None
659658

660-
if stopping_criterias == None:
659+
if stopping_criterias is None:
661660
stopping_criterias = []
662661

663662
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -1036,8 +1035,8 @@ def create_completion(
10361035
mirostat_tau: float = 5.0,
10371036
mirostat_eta: float = 0.1,
10381037
model: Optional[str] = None,
1039-
logits_processors=None,
1040-
stopping_criterias=None
1038+
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
1039+
stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None
10411040
) -> Union[Completion, Iterator[CompletionChunk]]:
10421041
"""Generate text from a prompt.
10431042

0 commit comments

Comments (0)