@@ -4,7 +4,7 @@
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
 from collections import deque, OrderedDict

 from . import llama_cpp
@@ -316,12 +316,11 @@ def _sample(
|
316 | 316 | mirostat_tau: llama_cpp.c_float,
|
317 | 317 | mirostat_eta: llama_cpp.c_float,
|
318 | 318 | penalize_nl: bool = True,
|
319 |
| - logits_processors=None |
| 319 | + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None |
320 | 320 | ):
|
321 | 321 | assert self.ctx is not None
|
322 | 322 | assert len(self.eval_logits) > 0
|
323 |
| - |
324 |
| - if logits_processors == None: |
| 323 | + if logits_processors is None: |
325 | 324 | logits_processors = []
|
326 | 325 |
|
327 | 326 | n_vocab = self.n_vocab()
|
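For reference, a logits processor under the new annotation is any callable that takes the tokens generated so far plus the current logits and returns an adjusted logits list. A minimal sketch, assuming plain Python ints and floats at the call site; the factory name and the banned-token rule are illustrative, not part of this change:

    from typing import List

    def suppress_token(token_id: int):
        # Illustrative assumption: build a processor that bans one token id.
        def processor(input_ids: List[int], logits: List[float]) -> List[float]:
            adjusted = list(logits)  # copy; avoid mutating the caller's buffer
            adjusted[token_id] = float("-inf")  # the banned token can never be sampled
            return adjusted
        return processor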
@@ -445,7 +444,7 @@ def sample(
|
445 | 444 | mirostat_eta: float = 0.1,
|
446 | 445 | mirostat_tau: float = 5.0,
|
447 | 446 | penalize_nl: bool = True,
|
448 |
| - logits_processors=None |
| 447 | + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None |
449 | 448 |
|
450 | 449 | ):
|
451 | 450 | """Sample a token from the model.
|
@@ -497,7 +496,7 @@ def generate(
|
497 | 496 | mirostat_mode: int = 0,
|
498 | 497 | mirostat_tau: float = 5.0,
|
499 | 498 | mirostat_eta: float = 0.1,
|
500 |
| - logits_processors=None |
| 499 | + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None |
501 | 500 | ) -> Generator[int, Optional[Sequence[int]], None]:
|
502 | 501 | """Create a generator of tokens from a prompt.
|
503 | 502 |
|
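A usage sketch for the generator path, assuming the suppress_token factory above and a locally loaded model; the model path, sampling values, and cutoff are placeholders:

    from llama_cpp import Llama

    llama = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
    tokens = llama.tokenize(b"Once upon a time")
    for i, token in enumerate(
        llama.generate(
            tokens,
            top_k=40,
            top_p=0.95,
            temp=0.8,
            repeat_penalty=1.1,
            logits_processors=[suppress_token(0)],
        )
    ):
        print(llama.detokenize([token]))
        if i >= 16:  # generate() keeps sampling; the caller decides when to stop
            break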
@@ -652,12 +651,12 @@ def _create_completion(
|
652 | 651 | mirostat_tau: float = 5.0,
|
653 | 652 | mirostat_eta: float = 0.1,
|
654 | 653 | model: Optional[str] = None,
|
655 |
| - logits_processors=None, |
656 |
| - stopping_criterias=None |
| 654 | + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None, |
| 655 | + stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None, |
657 | 656 | ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
658 | 657 | assert self.ctx is not None
|
659 | 658 |
|
660 |
| - if stopping_criterias == None: |
| 659 | + if stopping_criterias is None: |
661 | 660 | stopping_criterias = []
|
662 | 661 |
|
663 | 662 | completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
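Stopping criterias follow the same callable pattern: given the token history and the current logits, return True to end generation early. A minimal sketch with an assumed length cutoff; the name and the rule are illustrative, not part of this change:

    from typing import List

    def length_cutoff(max_len: int):
        # Illustrative assumption: stop once the sequence reaches max_len tokens.
        def criteria(input_ids: List[int], logits: List[float]) -> bool:
            return len(input_ids) >= max_len
        return criteria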
@@ -1036,8 +1035,8 @@ def create_completion(
|
1036 | 1035 | mirostat_tau: float = 5.0,
|
1037 | 1036 | mirostat_eta: float = 0.1,
|
1038 | 1037 | model: Optional[str] = None,
|
1039 |
| - logits_processors=None, |
1040 |
| - stopping_criterias=None |
| 1038 | + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None, |
| 1039 | + stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None |
1041 | 1040 | ) -> Union[Completion, Iterator[CompletionChunk]]:
|
1042 | 1041 | """Generate text from a prompt.
|
1043 | 1042 |
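Tying it together, both hooks would be passed through the public API roughly like this; a sketch assuming the suppress_token and length_cutoff factories above, with a placeholder model path and prompt:

    from llama_cpp import Llama

    llama = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
    output = llama.create_completion(
        "Q: Name the planets in the solar system. A: ",
        max_tokens=64,
        logits_processors=[suppress_token(llama.token_eos())],  # keep generating past EOS
        stopping_criterias=[length_cutoff(48)],  # but cut off at 48 tokens
    )
    print(output["choices"][0]["text"])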