Skip to content

Commit aa203a0

Browse files
Added mirostat sampling to the high-level API.
1 parent 2f2ea00 commit aa203a0

File tree

1 file changed

+83
-1
lines changed

1 file changed

+83
-1
lines changed

llama_cpp/llama.py

Lines changed: 83 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,11 @@ def _sample_top_p_top_k(
257257
top_k: llama_cpp.c_int,
258258
top_p: llama_cpp.c_float,
259259
temp: llama_cpp.c_float,
260+
mirostat_mode: llama_cpp.c_int,
261+
mirostat_tau: llama_cpp.c_float,
262+
mirostat_eta: llama_cpp.c_float,
263+
mirostat_mu: llama_cpp.c_float,
264+
mirostat_m: llama_cpp.c_int,
260265
repeat_penalty: llama_cpp.c_float,
261266
):
262267
assert self.ctx is not None
@@ -287,7 +292,34 @@ def _sample_top_p_top_k(
287292
candidates=llama_cpp.ctypes.pointer(candidates),
288293
penalty=repeat_penalty,
289294
)
290-
if float(temp.value) == 0.0:
295+
if mirostat_mode == 1:
296+
llama_cpp.llama_sample_temperature(
297+
ctx=self.ctx,
298+
candidates=llama_cpp.ctypes.pointer(candidates),
299+
temp=temp,
300+
)
301+
llama_cpp.llama_sample_token_mirostat(
302+
ctx=self.ctx,
303+
candidates=llama_cpp.ctypes.pointer(candidates),
304+
tau=mirostat_tau,
305+
eta=mirostat_eta,
306+
mu=mirostat_mu,
307+
m=mirostat_m
308+
)
309+
elif mirostat_mode == 2:
310+
llama_cpp.llama_sample_temperature(
311+
ctx=self.ctx,
312+
candidates=llama_cpp.ctypes.pointer(candidates),
313+
temp=temp,
314+
)
315+
llama_cpp.llama_sample_token_mirostat_v2(
316+
ctx=self.ctx,
317+
candidates=llama_cpp.ctypes.pointer(candidates),
318+
tau=mirostat_tau,
319+
eta=mirostat_eta,
320+
mu=mirostat_mu
321+
)
322+
elif float(temp.value) == 0.0:
291323
return llama_cpp.llama_sample_token_greedy(
292324
ctx=self.ctx,
293325
candidates=llama_cpp.ctypes.pointer(candidates),
@@ -328,6 +360,11 @@ def sample(
328360
top_k: int,
329361
top_p: float,
330362
temp: float,
363+
mirostat_mode: int,
364+
mirostat_tau: float,
365+
mirostat_eta: float,
366+
mirostat_mu: float,
367+
mirostat_m: int,
331368
repeat_penalty: float,
332369
):
333370
"""Sample a token from the model.
@@ -353,6 +390,11 @@ def sample(
353390
top_k=llama_cpp.c_int(top_k),
354391
top_p=llama_cpp.c_float(top_p),
355392
temp=llama_cpp.c_float(temp),
393+
mirostat=llama_cpp.c_int(mirostat_mode),
394+
mirostat_mu=llama_cpp.c_float(mirostat_mu),
395+
mirostat_tau=llama_cpp.c_float(mirostat_tau),
396+
mirostat_eta=llama_cpp.c_float(mirostat_eta),
397+
mirostat_m=llama_cpp.c_int(mirostat_m),
356398
repeat_penalty=llama_cpp.c_float(repeat_penalty),
357399
)
358400

@@ -362,6 +404,11 @@ def generate(
362404
top_k: int,
363405
top_p: float,
364406
temp: float,
407+
mirostat: int,
408+
mirostat_tau: float,
409+
mirostat_eta: float,
410+
mirostat_mu: float,
411+
mirostat_m: int,
365412
repeat_penalty: float,
366413
reset: bool = True,
367414
) -> Generator[
@@ -416,6 +463,11 @@ def generate(
416463
top_k=top_k,
417464
top_p=top_p,
418465
temp=temp,
466+
mirostat_mode=mirostat_mode,
467+
mirostat_tau=mirostat_tau,
468+
mirostat_eta=mirostat_eta,
469+
mirostat_mu=mirostat_mu,
470+
mirostat_m=mirostat_m,
419471
repeat_penalty=repeat_penalty,
420472
)
421473
tokens_or_none = yield token
@@ -486,6 +538,11 @@ def _create_completion(
486538
suffix: Optional[str] = None,
487539
max_tokens: int = 16,
488540
temperature: float = 0.8,
541+
mirostat_mode: int = 0,
542+
mirostat_tau: float = 5.0,
543+
mirostat_eta: float = 0.1,
544+
mirostat_mu: float = 10,
545+
mirostat_m: int = 100,
489546
top_p: float = 0.95,
490547
logprobs: Optional[int] = None,
491548
echo: bool = False,
@@ -536,6 +593,11 @@ def _create_completion(
536593
top_k=top_k,
537594
top_p=top_p,
538595
temp=temperature,
596+
mirostat_mode=mirostat_mode,
597+
mirostat_tau=mirostat_tau,
598+
mirostat_eta=mirostat_eta,
599+
mirostat_mu=mirostat_mu,
600+
mirostat_m=mirostat_m,
539601
repeat_penalty=repeat_penalty,
540602
):
541603
if token == llama_cpp.llama_token_eos():
@@ -707,6 +769,11 @@ def create_completion(
707769
suffix: Optional[str] = None,
708770
max_tokens: int = 128,
709771
temperature: float = 0.8,
772+
mirostat_mode: int = 0,
773+
mirostat_tau: float = 5.0,
774+
mirostat_eta: float = 0.1,
775+
mirostat_mu: float = 10,
776+
mirostat_m: int = 100,
710777
top_p: float = 0.95,
711778
logprobs: Optional[int] = None,
712779
echo: bool = False,
@@ -742,6 +809,11 @@ def create_completion(
742809
suffix=suffix,
743810
max_tokens=max_tokens,
744811
temperature=temperature,
812+
mirostat_mode=mirostat_mode,
813+
mirostat_tau=mirostat_tau,
814+
mirostat_eta=mirostat_eta,
815+
mirostat_mu=mirostat_mu,
816+
mirostat_m=mirostat_m,
745817
top_p=top_p,
746818
logprobs=logprobs,
747819
echo=echo,
@@ -762,6 +834,11 @@ def __call__(
762834
suffix: Optional[str] = None,
763835
max_tokens: int = 128,
764836
temperature: float = 0.8,
837+
mirostat_mode: int = 0,
838+
mirostat_tau: float = 5.0,
839+
mirostat_eta: float = 0.1,
840+
mirostat_mu: float = 10,
841+
mirostat_m: int = 100,
765842
top_p: float = 0.95,
766843
logprobs: Optional[int] = None,
767844
echo: bool = False,
@@ -797,6 +874,11 @@ def __call__(
797874
suffix=suffix,
798875
max_tokens=max_tokens,
799876
temperature=temperature,
877+
mirostat_mode=mirostat_mode,
878+
mirostat_tau=mirostat_tau,
879+
mirostat_eta=mirostat_eta,
880+
mirostat_mu=mirostat_mu,
881+
mirostat_m=mirostat_m,
800882
top_p=top_p,
801883
logprobs=logprobs,
802884
echo=echo,

0 commit comments

Comments
 (0)