@@ -257,6 +257,11 @@ def _sample_top_p_top_k(
257
257
top_k : llama_cpp .c_int ,
258
258
top_p : llama_cpp .c_float ,
259
259
temp : llama_cpp .c_float ,
260
+ mirostat_mode : llama_cpp .c_int ,
261
+ mirostat_tau : llama_cpp .c_float ,
262
+ mirostat_eta : llama_cpp .c_float ,
263
+ mirostat_mu : llama_cpp .c_float ,
264
+ mirostat_m : llama_cpp .c_int ,
260
265
repeat_penalty : llama_cpp .c_float ,
261
266
):
262
267
assert self .ctx is not None
@@ -287,7 +292,34 @@ def _sample_top_p_top_k(
287
292
candidates = llama_cpp .ctypes .pointer (candidates ),
288
293
penalty = repeat_penalty ,
289
294
)
290
- if float (temp .value ) == 0.0 :
295
+ if mirostat_mode == 1 :
296
+ llama_cpp .llama_sample_temperature (
297
+ ctx = self .ctx ,
298
+ candidates = llama_cpp .ctypes .pointer (candidates ),
299
+ temp = temp ,
300
+ )
301
+ llama_cpp .llama_sample_token_mirostat (
302
+ ctx = self .ctx ,
303
+ candidates = llama_cpp .ctypes .pointer (candidates ),
304
+ tau = mirostat_tau ,
305
+ eta = mirostat_eta ,
306
+ mu = mirostat_mu ,
307
+ m = mirostat_m
308
+ )
309
+ elif mirostat_mode == 2 :
310
+ llama_cpp .llama_sample_temperature (
311
+ ctx = self .ctx ,
312
+ candidates = llama_cpp .ctypes .pointer (candidates ),
313
+ temp = temp ,
314
+ )
315
+ llama_cpp .llama_sample_token_mirostat_v2 (
316
+ ctx = self .ctx ,
317
+ candidates = llama_cpp .ctypes .pointer (candidates ),
318
+ tau = mirostat_tau ,
319
+ eta = mirostat_eta ,
320
+ mu = mirostat_mu
321
+ )
322
+ elif float (temp .value ) == 0.0 :
291
323
return llama_cpp .llama_sample_token_greedy (
292
324
ctx = self .ctx ,
293
325
candidates = llama_cpp .ctypes .pointer (candidates ),
@@ -328,6 +360,11 @@ def sample(
328
360
top_k : int ,
329
361
top_p : float ,
330
362
temp : float ,
363
+ mirostat_mode : int ,
364
+ mirostat_tau : float ,
365
+ mirostat_eta : float ,
366
+ mirostat_mu : float ,
367
+ mirostat_m : int ,
331
368
repeat_penalty : float ,
332
369
):
333
370
"""Sample a token from the model.
@@ -353,6 +390,11 @@ def sample(
353
390
top_k = llama_cpp .c_int (top_k ),
354
391
top_p = llama_cpp .c_float (top_p ),
355
392
temp = llama_cpp .c_float (temp ),
393
+ mirostat = llama_cpp .c_int (mirostat_mode ),
394
+ mirostat_mu = llama_cpp .c_float (mirostat_mu ),
395
+ mirostat_tau = llama_cpp .c_float (mirostat_tau ),
396
+ mirostat_eta = llama_cpp .c_float (mirostat_eta ),
397
+ mirostat_m = llama_cpp .c_int (mirostat_m ),
356
398
repeat_penalty = llama_cpp .c_float (repeat_penalty ),
357
399
)
358
400
@@ -362,6 +404,11 @@ def generate(
362
404
top_k : int ,
363
405
top_p : float ,
364
406
temp : float ,
407
+ mirostat : int ,
408
+ mirostat_tau : float ,
409
+ mirostat_eta : float ,
410
+ mirostat_mu : float ,
411
+ mirostat_m : int ,
365
412
repeat_penalty : float ,
366
413
reset : bool = True ,
367
414
) -> Generator [
@@ -416,6 +463,11 @@ def generate(
416
463
top_k = top_k ,
417
464
top_p = top_p ,
418
465
temp = temp ,
466
+ mirostat_mode = mirostat_mode ,
467
+ mirostat_tau = mirostat_tau ,
468
+ mirostat_eta = mirostat_eta ,
469
+ mirostat_mu = mirostat_mu ,
470
+ mirostat_m = mirostat_m ,
419
471
repeat_penalty = repeat_penalty ,
420
472
)
421
473
tokens_or_none = yield token
@@ -486,6 +538,11 @@ def _create_completion(
486
538
suffix : Optional [str ] = None ,
487
539
max_tokens : int = 16 ,
488
540
temperature : float = 0.8 ,
541
+ mirostat_mode : int = 0 ,
542
+ mirostat_tau : float = 5.0 ,
543
+ mirostat_eta : float = 0.1 ,
544
+ mirostat_mu : float = 10 ,
545
+ mirostat_m : int = 100 ,
489
546
top_p : float = 0.95 ,
490
547
logprobs : Optional [int ] = None ,
491
548
echo : bool = False ,
@@ -536,6 +593,11 @@ def _create_completion(
536
593
top_k = top_k ,
537
594
top_p = top_p ,
538
595
temp = temperature ,
596
+ mirostat_mode = mirostat_mode ,
597
+ mirostat_tau = mirostat_tau ,
598
+ mirostat_eta = mirostat_eta ,
599
+ mirostat_mu = mirostat_mu ,
600
+ mirostat_m = mirostat_m ,
539
601
repeat_penalty = repeat_penalty ,
540
602
):
541
603
if token == llama_cpp .llama_token_eos ():
@@ -707,6 +769,11 @@ def create_completion(
707
769
suffix : Optional [str ] = None ,
708
770
max_tokens : int = 128 ,
709
771
temperature : float = 0.8 ,
772
+ mirostat_mode : int = 0 ,
773
+ mirostat_tau : float = 5.0 ,
774
+ mirostat_eta : float = 0.1 ,
775
+ mirostat_mu : float = 10 ,
776
+ mirostat_m : int = 100 ,
710
777
top_p : float = 0.95 ,
711
778
logprobs : Optional [int ] = None ,
712
779
echo : bool = False ,
@@ -742,6 +809,11 @@ def create_completion(
742
809
suffix = suffix ,
743
810
max_tokens = max_tokens ,
744
811
temperature = temperature ,
812
+ mirostat_mode = mirostat_mode ,
813
+ mirostat_tau = mirostat_tau ,
814
+ mirostat_eta = mirostat_eta ,
815
+ mirostat_mu = mirostat_mu ,
816
+ mirostat_m = mirostat_m ,
745
817
top_p = top_p ,
746
818
logprobs = logprobs ,
747
819
echo = echo ,
@@ -762,6 +834,11 @@ def __call__(
762
834
suffix : Optional [str ] = None ,
763
835
max_tokens : int = 128 ,
764
836
temperature : float = 0.8 ,
837
+ mirostat_mode : int = 0 ,
838
+ mirostat_tau : float = 5.0 ,
839
+ mirostat_eta : float = 0.1 ,
840
+ mirostat_mu : float = 10 ,
841
+ mirostat_m : int = 100 ,
765
842
top_p : float = 0.95 ,
766
843
logprobs : Optional [int ] = None ,
767
844
echo : bool = False ,
@@ -797,6 +874,11 @@ def __call__(
797
874
suffix = suffix ,
798
875
max_tokens = max_tokens ,
799
876
temperature = temperature ,
877
+ mirostat_mode = mirostat_mode ,
878
+ mirostat_tau = mirostat_tau ,
879
+ mirostat_eta = mirostat_eta ,
880
+ mirostat_mu = mirostat_mu ,
881
+ mirostat_m = mirostat_m ,
800
882
top_p = top_p ,
801
883
logprobs = logprobs ,
802
884
echo = echo ,
0 commit comments