1
1
import os
2
+ from pathlib import Path
2
3
import sys
3
4
import uuid
4
5
import time
23
24
24
25
from . import llama_cpp
25
26
from .llama_types import *
27
+ from .llama_grammar import LlamaGrammar
26
28
27
29
import numpy as np
28
30
import numpy .typing as npt
@@ -223,6 +225,7 @@ def __init__(
223
225
tensor_split : Optional [List [float ]] = None ,
224
226
rope_freq_base : float = 10000.0 ,
225
227
rope_freq_scale : float = 1.0 ,
228
+ grammar : Optional [Union [str , Path ]] = None ,
226
229
n_gqa : Optional [int ] = None , # (TEMPORARY) must be 8 for llama2 70b
227
230
rms_norm_eps : Optional [float ] = None , # (TEMPORARY)
228
231
verbose : bool = True ,
@@ -248,6 +251,7 @@ def __init__(
248
251
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
249
252
rope_freq_base: Base frequency for rope sampling.
250
253
rope_freq_scale: Scale factor for rope sampling.
254
+ grammar: Path to a BNF grammar file to use for grammar based sampling.
251
255
verbose: Print verbose output to stderr.
252
256
253
257
Raises:
@@ -358,6 +362,12 @@ def __init__(
358
362
self .scores : npt .NDArray [np .single ] = np .ndarray (
359
363
(n_ctx , self ._n_vocab ), dtype = np .single
360
364
)
365
+ if grammar is not None :
366
+ self .grammar = LlamaGrammar .from_file (
367
+ grammar
368
+ ) # type: Optional[LlamaGrammar]
369
+ else :
370
+ self .grammar = None
361
371
362
372
@property
363
373
def _input_ids (self ) -> npt .NDArray [np .intc ]:
@@ -542,8 +552,16 @@ def _sample(
542
552
)
543
553
if not penalize_nl :
544
554
candidates .data [self ._token_nl ].logit = llama_cpp .c_float (nl_logit )
555
+
556
+ if self .grammar is not None :
557
+ llama_cpp .llama_sample_grammar (
558
+ ctx = self .ctx ,
559
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
560
+ grammar = self .grammar .grammar ,
561
+ )
562
+
545
563
if temp .value == 0.0 :
546
- return llama_cpp .llama_sample_token_greedy (
564
+ id = llama_cpp .llama_sample_token_greedy (
547
565
ctx = self .ctx ,
548
566
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
549
567
)
@@ -555,7 +573,7 @@ def _sample(
555
573
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
556
574
temp = temp ,
557
575
)
558
- return llama_cpp .llama_sample_token_mirostat (
576
+ id = llama_cpp .llama_sample_token_mirostat (
559
577
ctx = self .ctx ,
560
578
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
561
579
tau = mirostat_tau ,
@@ -570,7 +588,7 @@ def _sample(
570
588
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
571
589
temp = temp ,
572
590
)
573
- return llama_cpp .llama_sample_token_mirostat_v2 (
591
+ id = llama_cpp .llama_sample_token_mirostat_v2 (
574
592
ctx = self .ctx ,
575
593
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
576
594
tau = mirostat_tau ,
@@ -607,10 +625,17 @@ def _sample(
607
625
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
608
626
temp = temp ,
609
627
)
610
- return llama_cpp .llama_sample_token (
628
+ id = llama_cpp .llama_sample_token (
611
629
ctx = self .ctx ,
612
630
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
613
631
)
632
+ if self .grammar is not None :
633
+ llama_cpp .llama_grammar_accept_token (
634
+ ctx = self .ctx ,
635
+ grammar = self .grammar .grammar ,
636
+ token = llama_cpp .ctypes .c_int (id ),
637
+ )
638
+ return id
614
639
615
640
def sample (
616
641
self ,
@@ -1509,6 +1534,9 @@ def __del__(self):
1509
1534
if self .ctx is not None :
1510
1535
llama_cpp .llama_free (self .ctx )
1511
1536
self .ctx = None
1537
+ if self .grammar is not None :
1538
+ llama_cpp .llama_grammar_free (self .grammar .grammar )
1539
+ self .grammar = None
1512
1540
1513
1541
def __getstate__ (self ):
1514
1542
return dict (
0 commit comments