Commit 8044fd8

Add beam search. Invoke by adding "beam_width": 2 (for example) to the /v1/completions POST body.

1 parent 3afbf2e, commit 8044fd8

File tree: 3 files changed, +75 −2 lines

llama_cpp/llama.py

Lines changed: 71 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
     Callable,
 )
 from collections import deque, OrderedDict
+from dataclasses import dataclass
 
 import diskcache
 import ctypes
@@ -205,6 +206,44 @@ def __call__(
     ) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
 
+# Custom data that is accessible to the beam_search_callback() function.
+@dataclass
+class beam_search_callback_data:
+    ctx: llama_cpp.llama_context_p
+    response_tokens: List[int]
+
+# Used for debugging to view beam states.
+def beam_view_to_string(ctx, beam_view):
+    string = f"p({beam_view.p}): "
+    for i in range(beam_view.n_tokens):
+        string += llama_cpp.llama_token_get_text(ctx, beam_view.tokens[i]).decode("utf-8")
+    return string
+
+# One requirement of the callback is that it MUST determine end-of-beam.
+def is_at_eob(ctx, tokens, n_tokens):
+    return 0 < n_tokens and tokens[n_tokens - 1] == llama_cpp.llama_token_eos(ctx)
+
+# beam_search_callback requires a global dictionary to pass data via the object's id.
+beam_search_dictionary = {}
+
+# beam_search_callback() must flag beams when they reach end-of-sentence.
+# TODO: Use stop_sequences.
+def beam_search_callback(callback_data_id, beams_state):
+    callback_data = beam_search_dictionary[callback_data_id]
+    for i in range(beams_state.n_beams):
+        beam_view = beams_state.beam_views[i]
+        if not beam_view.eob and is_at_eob(callback_data.ctx, beam_view.tokens, beam_view.n_tokens):
+            beam_view.eob = True  # Flag beams as EOB as required.
+    # Collect tokens into callback_data.response_tokens.
+    if 0 < beams_state.common_prefix_length:
+        assert 0 < beams_state.n_beams
+        tokens = ctypes.cast(beams_state.beam_views[0].tokens, ctypes.POINTER(ctypes.c_int * beams_state.common_prefix_length)).contents
+        callback_data.response_tokens.extend(tokens)
+
+    # DEBUG: print beams and their relative probabilities.
+    # print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
+    # for i in range(beams_state.n_beams):
+    #     print(f"beams[{i}]", beam_view_to_string(callback_data.ctx, beams_state.beam_views[i]))
 
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
```
```diff
@@ -494,6 +533,7 @@ def eval(self, tokens: Sequence[int]):
             tokens: The list of tokens to evaluate.
         """
         assert self.ctx is not None
+
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
@@ -734,6 +774,7 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -775,6 +816,26 @@ def generate(
         if grammar is not None:
             grammar.reset()
 
+        if 0 < beam_width:
+            self.eval(tokens)
+            callback_data = beam_search_callback_data(self.ctx, [])
+            beam_search_dictionary[id(callback_data)] = callback_data
+            callback = llama_cpp.llama_beam_search_callback_fn_t(beam_search_callback)
+            n_remain = llama_cpp.llama_n_ctx(self.ctx) - self.n_tokens
+            llama_cpp.llama_beam_search(self.ctx, callback, id(callback_data),
+                                        beam_width,
+                                        self.n_tokens,
+                                        n_remain,
+                                        self.n_threads)
+            beam_search_dictionary.pop(id(callback_data))
+            # It would be nicer if we could yield from within the callback, but that is impossible.
+            for token in callback_data.response_tokens:
+                np.append(self.input_ids, [token])
+                np.append(self.scores, [0.0])
+                self.n_tokens += 1
+                yield token
+            return
+
         while True:
             self.eval(tokens)
             token = self.sample(
```
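The branch above can be exercised directly from the token-level API. A hypothetical sketch (the model path and prompt are assumptions, not part of the commit):

```python
# Hypothetical use of Llama.generate() with the new beam_width argument.
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/model.gguf")  # path is an assumption
prompt_tokens = llm.tokenize(b"The capital of France is")

generated = []
for token in llm.generate(prompt_tokens, beam_width=2):
    generated.append(token)
    if len(generated) >= 32:  # beam search stops at EOS; also cap the output here
        break

print(llm.detokenize(generated).decode("utf-8", errors="ignore"))
```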
```diff
@@ -791,6 +852,7 @@ def generate(
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
+
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids, self._scores[-1, :]
             ):
@@ -893,6 +955,7 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -971,6 +1034,7 @@ def _create_completion(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1354,6 +1418,7 @@ def create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1369,6 +1434,7 @@ def create_completion(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1398,7 +1464,8 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
-            grammar=grammar
+            grammar=grammar,
+            beam_width=beam_width,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1429,6 +1496,7 @@ def __call__(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1444,6 +1512,7 @@ def __call__(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1474,6 +1543,7 @@ def __call__(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         )
 
     def _convert_text_completion_to_chat(
```
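The remaining hunks simply thread `beam_width` through `_create_completion`, `create_completion`, and `__call__`, so it is also reachable from the high-level completion API. A hypothetical usage sketch (model path and settings are illustrative):

```python
# Hypothetical call through Llama.__call__ / create_completion with beam_width.
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/model.gguf")  # path is an assumption
output = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    beam_width=2,  # 0 (the default) keeps the original sampling loop
)
print(output["choices"][0]["text"])
```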

llama_cpp/llama_cpp.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -1456,7 +1456,7 @@ class llama_beams_state(ctypes.Structure):
 # LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
 def llama_beam_search(
     ctx: llama_context_p,
-    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
+    callback: llama_beam_search_callback_fn_t,
     callback_data: c_void_p,
     n_beams: Union[c_size_t, int],
     n_past: Union[c_int, int],
@@ -1467,6 +1467,8 @@ def llama_beam_search(
         ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
     )
 
+_lib.llama_beam_search.argtypes = [llama_context_p, llama_beam_search_callback_fn_t, c_void_p, c_size_t, c_int, c_int, c_int]
+_lib.llama_beam_search.restype = None
 
 # Performance information
 
```
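Declaring `argtypes` and `restype` tells ctypes how to marshal the context pointer, the callback object, and the integer arguments into the C call. The same CFUNCTYPE-plus-`argtypes` pattern, shown against libc's `qsort` as a stand-alone illustration (not llama.cpp code; assumes a typical Linux/macOS libc):

```python
# Stand-alone illustration of the ctypes pattern above: a CFUNCTYPE callback
# type plus explicit argtypes/restype on the foreign function.
import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))

# int (*compar)(const int *, const int *)
CMPFUNC = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int))

libc.qsort.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_size_t, ctypes.c_size_t, CMPFUNC]
libc.qsort.restype = None

def compare(a, b):
    return a[0] - b[0]

values = (ctypes.c_int * 5)(5, 1, 7, 33, 99)
libc.qsort(values, len(values), ctypes.sizeof(ctypes.c_int), CMPFUNC(compare))
print(list(values))  # -> [1, 5, 7, 33, 99]
```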

llama_cpp/server/app.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -549,6 +549,7 @@ class CreateCompletionRequest(BaseModel):
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
+    beam_width: int = 0
 
     model_config = {
         "json_schema_extra": {
```
