    Callable,
)
from collections import deque, OrderedDict
+from dataclasses import dataclass

import diskcache
import ctypes
@@ -205,6 +206,44 @@ def __call__(
    ) -> bool:
        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])

+# Custom data that is accessible to the beam_search_callback() function.
+@dataclass
+class beam_search_callback_data:
+    ctx: llama_cpp.llama_context_p
+    response_tokens: List[int]
+
+# Used for debugging to view beam states
+def beam_view_to_string(ctx, beam_view):
+    string = f"p({beam_view.p}): "
+    for i in range(beam_view.n_tokens):
+        string += llama_cpp.llama_token_get_text(ctx, beam_view.tokens[i]).decode("utf-8")
+    return string
+
+# One requirement of the callback is that it MUST determine end-of-beam.
+def is_at_eob(ctx, tokens, n_tokens):
+    return 0 < n_tokens and tokens[n_tokens - 1] == llama_cpp.llama_token_eos(ctx)
+
+# beam_search_callback requires a global dictionary to pass data via its object id.
+beam_search_dictionary = {}
+
+# beam_search_callback() must flag beams when they reach end-of-sentence.
+# TODO: Use stop_sequences.
+def beam_search_callback(callback_data_id, beams_state):
+    callback_data = beam_search_dictionary[callback_data_id]
+    for i in range(beams_state.n_beams):
+        beam_view = beams_state.beam_views[i]
+        if not beam_view.eob and is_at_eob(callback_data.ctx, beam_view.tokens, beam_view.n_tokens):
+            beam_view.eob = True  # Flag beams as EOB as required.
+    # Collect tokens into callback_data.response_tokens
+    if 0 < beams_state.common_prefix_length:
+        assert 0 < beams_state.n_beams
+        tokens = ctypes.cast(beams_state.beam_views[0].tokens, ctypes.POINTER(ctypes.c_int * beams_state.common_prefix_length)).contents
+        callback_data.response_tokens.extend(tokens)
+
+    # DEBUG print beams and their relative probabilities
+    # print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
+    # for i in range(beams_state.n_beams):
+    #     print(f"beams[{i}]", beam_view_to_string(callback_data.ctx, beams_state.beam_views[i]))


class Llama:
    """High-level Python wrapper for a llama.cpp model."""
@@ -494,6 +533,7 @@ def eval(self, tokens: Sequence[int]):
            tokens: The list of tokens to evaluate.
        """
        assert self.ctx is not None
+
        n_ctx = self._n_ctx
        for i in range(0, len(tokens), self.n_batch):
            batch = tokens[i : min(len(tokens), i + self.n_batch)]
@@ -734,6 +774,7 @@ def generate(
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Create a generator of tokens from a prompt.

@@ -775,6 +816,26 @@ def generate(
        if grammar is not None:
            grammar.reset()

+        if 0 < beam_width:
+            self.eval(tokens)
+            callback_data = beam_search_callback_data(self.ctx, [])
+            beam_search_dictionary[id(callback_data)] = callback_data
+            callback = llama_cpp.llama_beam_search_callback_fn_t(beam_search_callback)
+            n_remain = llama_cpp.llama_n_ctx(self.ctx) - self.n_tokens
+            llama_cpp.llama_beam_search(self.ctx, callback, id(callback_data),
+                                        beam_width,
+                                        self.n_tokens,
+                                        n_remain,
+                                        self.n_threads)
+            beam_search_dictionary.pop(id(callback_data))
+            # It would be nicer if we could yield from within the callback, but that is impossible.
+            for token in callback_data.response_tokens:
+                # Write into the preallocated buffers; np.append() returns a new
+                # array and would leave self.input_ids / self.scores unchanged.
+                self.input_ids[self.n_tokens] = token
+                self.scores[self.n_tokens, :] = 0.0
+                self.n_tokens += 1
+                yield token
+            return
+
        while True:
            self.eval(tokens)
            token = self.sample(
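Design note: only an integer crosses the C boundary here. `llama_cpp.llama_beam_search` receives `id(callback_data)` and hands it back to the callback, and Python has no supported way to turn an id back into an object, so the module-level `beam_search_dictionary` does that mapping. Below is a minimal, self-contained sketch of the pattern; the callback signature is a made-up stand-in, not the real `llama_beam_search_callback_fn_t`.

```python
import ctypes

# Stand-in C callback type: (user_data_id, opaque_state) -> None.
EXAMPLE_CALLBACK_FN = ctypes.CFUNCTYPE(None, ctypes.c_size_t, ctypes.c_void_p)

callback_registry = {}  # mirrors beam_search_dictionary: id(obj) -> obj


class ExampleData:
    def __init__(self):
        self.response_tokens = []


def example_callback(data_id, _state):
    # Recover the Python object from the integer id the C side handed back.
    data = callback_registry[data_id]
    data.response_tokens.append(42)  # placeholder for real token collection


data = ExampleData()
callback_registry[id(data)] = data
c_callback = EXAMPLE_CALLBACK_FN(example_callback)

# A real caller (llama_cpp.llama_beam_search) would invoke the callback from C;
# calling it directly here just demonstrates the round trip.
c_callback(id(data), None)
callback_registry.pop(id(data))
print(data.response_tokens)  # [42]
```

As in the patch, the registry entry is removed with `pop()` once the search returns, so the dictionary does not grow across calls.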
@@ -791,6 +852,7 @@ def generate(
                logits_processor=logits_processor,
                grammar=grammar,
            )
+
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids, self._scores[-1, :]
            ):
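For reference, a hedged usage sketch of the low-level path (the model path and prompt are placeholders, not part of this patch): with `beam_width > 0`, `generate()` evaluates the prompt, runs `llama_beam_search` once, and then yields every token of the winning beam before returning, so the loop below consumes the whole beam rather than per-token samples.

```python
from llama_cpp import Llama

# Hypothetical model path; substitute any local GGUF model.
llm = Llama(model_path="./models/llama-2-7b.Q4_K_M.gguf")

prompt_tokens = llm.tokenize(b"The planets of the solar system are")

# beam_width > 0 switches generate() from per-token sampling to beam search.
for token in llm.generate(prompt_tokens, beam_width=4):
    if token == llm.token_eos():
        break
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="", flush=True)
```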
@@ -893,6 +955,7 @@ def _create_completion(
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
        assert self.ctx is not None

@@ -971,6 +1034,7 @@ def _create_completion(
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            grammar=grammar,
+            beam_width=beam_width,
        ):
            if token == self._token_eos:
                text = self.detokenize(completion_tokens)
@@ -1354,6 +1418,7 @@ def create_completion(
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

@@ -1369,6 +1434,7 @@ def create_completion(
            repeat_penalty: The penalty to apply to repeated tokens.
            top_k: The top-k value to use for sampling.
            stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.

        Raises:
            ValueError: If the requested tokens exceed the context window.
@@ -1398,7 +1464,8 @@ def create_completion(
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
-            grammar=grammar
+            grammar=grammar,
+            beam_width=beam_width,
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1429,6 +1496,7 @@ def __call__(
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

@@ -1444,6 +1512,7 @@ def __call__(
            repeat_penalty: The penalty to apply to repeated tokens.
            top_k: The top-k value to use for sampling.
            stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.

        Raises:
            ValueError: If the requested tokens exceed the context window.
@@ -1474,6 +1543,7 @@ def __call__(
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            grammar=grammar,
+            beam_width=beam_width,
        )

    def _convert_text_completion_to_chat(
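At the high-level API, the new keyword simply threads through `__call__` / `create_completion` / `_create_completion`. A hedged usage sketch (model path and prompt are placeholders):

```python
from llama_cpp import Llama

# Hypothetical model path; substitute any local GGUF model.
llm = Llama(model_path="./models/llama-2-7b.Q4_K_M.gguf")

# beam_width=0 (the default) keeps the existing sampling behaviour;
# beam_width > 0 routes the completion through beam search.
result = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    beam_width=4,
)
print(result["choices"][0]["text"])
```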