Skip to content

Commit 418aa83

Browse files
committed
Added grammar-based sampling
1 parent ac188a2 commit 418aa83

File tree

2 files changed

+512
-518
lines changed

llama_cpp/llama.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23
import sys
34
import uuid
45
import time
@@ -23,6 +24,7 @@
2324

2425
from . import llama_cpp
2526
from .llama_types import *
27+
from .llama_grammar import LlamaGrammar
2628

2729
import numpy as np
2830
import numpy.typing as npt
@@ -223,6 +225,7 @@ def __init__(
223225
tensor_split: Optional[List[float]] = None,
224226
rope_freq_base: float = 10000.0,
225227
rope_freq_scale: float = 1.0,
228+
grammar: Optional[Union[str, Path]] = None,
226229
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
227230
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
228231
verbose: bool = True,
@@ -248,6 +251,7 @@ def __init__(
248251
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
249252
rope_freq_base: Base frequency for rope sampling.
250253
rope_freq_scale: Scale factor for rope sampling.
254+
grammar: Path to a BNF grammar file to use for grammar based sampling.
251255
verbose: Print verbose output to stderr.
252256
253257
Raises:
@@ -358,6 +362,12 @@ def __init__(
358362
self.scores: npt.NDArray[np.single] = np.ndarray(
359363
(n_ctx, self._n_vocab), dtype=np.single
360364
)
365+
if grammar is not None:
366+
self.grammar = LlamaGrammar.from_file(
367+
grammar
368+
) # type: Optional[LlamaGrammar]
369+
else:
370+
self.grammar = None
361371

362372
@property
363373
def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -542,8 +552,16 @@ def _sample(
542552
)
543553
if not penalize_nl:
544554
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
555+
556+
if self.grammar is not None:
557+
llama_cpp.llama_sample_grammar(
558+
ctx=self.ctx,
559+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
560+
grammar=self.grammar.grammar,
561+
)
562+
545563
if temp.value == 0.0:
546-
return llama_cpp.llama_sample_token_greedy(
564+
id = llama_cpp.llama_sample_token_greedy(
547565
ctx=self.ctx,
548566
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
549567
)
@@ -555,7 +573,7 @@ def _sample(
555573
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
556574
temp=temp,
557575
)
558-
return llama_cpp.llama_sample_token_mirostat(
576+
id = llama_cpp.llama_sample_token_mirostat(
559577
ctx=self.ctx,
560578
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
561579
tau=mirostat_tau,
@@ -570,7 +588,7 @@ def _sample(
570588
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
571589
temp=temp,
572590
)
573-
return llama_cpp.llama_sample_token_mirostat_v2(
591+
id = llama_cpp.llama_sample_token_mirostat_v2(
574592
ctx=self.ctx,
575593
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
576594
tau=mirostat_tau,
@@ -607,10 +625,17 @@ def _sample(
607625
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
608626
temp=temp,
609627
)
610-
return llama_cpp.llama_sample_token(
628+
id = llama_cpp.llama_sample_token(
611629
ctx=self.ctx,
612630
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
613631
)
632+
if self.grammar is not None:
633+
llama_cpp.llama_grammar_accept_token(
634+
ctx=self.ctx,
635+
grammar=self.grammar.grammar,
636+
token=llama_cpp.ctypes.c_int(id),
637+
)
638+
return id
614639

615640
def sample(
616641
self,
@@ -1509,6 +1534,9 @@ def __del__(self):
15091534
if self.ctx is not None:
15101535
llama_cpp.llama_free(self.ctx)
15111536
self.ctx = None
1537+
if self.grammar is not None:
1538+
llama_cpp.llama_grammar_free(self.grammar.grammar)
1539+
self.grammar = None
15121540

15131541
def __getstate__(self):
15141542
return dict(

0 commit comments

Comments
 (0)