
Commit ce55910

Add llama_beam_search_callback_fn_t and improve comments.
1 parent 2beb315 commit ce55910

3 files changed: +55 -41 lines changed

examples/beam_search/beam_search.cpp

Lines changed: 25 additions & 3 deletions
@@ -28,6 +28,22 @@
 #endif
 
 
+// Custom callback example is called each time the beams lengths increase:
+// * Show progress by printing ',' following by number of convergent beam tokens if any.
+// * When all beams converge to a common prefix, they are made available in tokens_view.
+// This is also called when the stop condition is met, in which case the beam with the
+// highest probability is chosen, and its remaining tokens are available in tokens_view.
+// Collect them into std::vector<llama_token> response which is pointed to by callback_state.
+void beam_search_callback(void* callback_state, llama_tokens_view const tokens_view) {
+    printf(","); // Show progress
+    if (size_t const n = tokens_view.size) {
+        auto* response = static_cast<std::vector<llama_token>*>(callback_state);
+        response->resize(response->size() + n);
+        std::copy(tokens_view.tokens, tokens_view.tokens + n, response->end() - n);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+}
 
 int main(int argc, char ** argv)
 {
@@ -115,9 +131,15 @@ int main(int argc, char ** argv)
     }
     n_past += tokens_list.size();
 
-    int const n_predict = 1024;
-    char const* response = llama_beam_search(ctx, params.n_beams, n_past, n_predict, params.n_threads);
-    printf("\nDone:\n\n%s%s\n", params.prompt.c_str(), response);
+    std::vector<llama_token> response;
+    int const n_predict = 256;
+    llama_beam_search(ctx, beam_search_callback, &response, params.n_beams, n_past, n_predict, params.n_threads);
+
+    printf("\n\n");
+    for (llama_token const token_id : response) {
+        printf("%s", llama_token_to_str(ctx,token_id));
+    }
+    printf("\n");
 #else
     //---------------------------------
     // Main prediction loop :
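Because the callback receives an opaque void* state, a caller is not limited to accumulating tokens as the example above does. The following sketch is illustrative only and not part of this commit: a hypothetical streaming variant that prints each converged token prefix as it arrives, passing the llama_context itself as the state. It relies only on the llama_tokens_view struct and llama_token_to_str() already used in this diff.

    // Hypothetical streaming variant of the callback (illustrative, not in this commit).
    void beam_search_stream_callback(void* callback_state, llama_tokens_view const tokens_view) {
        auto* ctx = static_cast<llama_context*>(callback_state); // here the state is the context itself
        for (size_t i = 0; i < tokens_view.size; ++i) {
            printf("%s", llama_token_to_str(ctx, tokens_view.tokens[i]));
        }
        fflush(stdout);
    }

    // Invoked as (other parameters as in the example above):
    // llama_beam_search(ctx, beam_search_stream_callback, ctx, params.n_beams, n_past, n_predict, params.n_threads);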

llama.cpp

Lines changed: 23 additions & 37 deletions
@@ -2977,8 +2977,8 @@ struct beam_search {
     int common_prefix_length;
     // true iff llama_eval() has been called with common prefix in current loop iteration.
     bool common_prefix_evaluated;
-    // Save token prefix common to all beams here
-    std::vector<llama_token> response;
+    // Save token prefix common to all beams. Cleared after each loop iteration.
+    std::vector<llama_token> common_prefix;
 
     beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
         : ctx(ctx)
@@ -3005,11 +3005,11 @@ struct beam_search {
         return common_prefix_length;
     }
 
-    // Min-heaps are used to efficiently gather the top-k elements (k=beam_width).
+    // Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
     // The repetative patterns below reflect the 2 stages of heaps:
     // * Gather elements until the vector is full, then call std::make_heap() on it.
-    // * If the heap is full and a new element is found that should be included,
-    //   pop off the least element, replace it with the new, then push it into the heap.
+    // * If the heap is full and a new element is found that should be included, pop the
+    //   least element to the back(), replace it with the new, then push it into the heap.
     void fill_next_beams_by_top_probabilities(beam& b) {
         // Min-heaps use a greater-than comparator.
         auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
@@ -3021,7 +3021,7 @@ struct beam_search {
         if (b.eos()) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < static_cast<size_t>(beam_width)) {
-                next_beams.push_back(b);
+                next_beams.push_back(std::move(b));
                 if (next_beams.size() == static_cast<size_t>(beam_width)) {
                     std::make_heap(next_beams.begin(), next_beams.end(), comp);
                 }
@@ -3035,7 +3035,7 @@ struct beam_search {
            if (!b.tokens.empty()) {
                llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
                if (!common_prefix_evaluated && common_prefix_length) {
-                   b.shift_tokens(response, common_prefix_length);
+                   b.shift_tokens(common_prefix, common_prefix_length);
                    n_past += common_prefix_length;
                    common_prefix_evaluated = true;
                }
@@ -3074,12 +3074,12 @@ struct beam_search {
     }
 
     // Loop:
-    // * while i < n_predict
-    // * until all of the beams have nreached end-of-sentence
+    // * while i < n_predict, OR
+    // * until all of the beams have nreached end-of-sentence, OR
     // * until the highest probability beam is at end-of-sentence
     //   (since all other beam probabilities can only decrease)
-    void loop(std::function<void(std::vector<beam>&)> const callback) {
-        beams.push_back({{}, 1.0f});
+    void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
+        beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
         auto const eos = [](beam const& beam) { return beam.eos(); };
         for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
             common_prefix_evaluated = false;
@@ -3090,10 +3090,14 @@ struct beam_search {
             beams.swap(next_beams);
             renormalize_beam_probabilities(beams);
             std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
-            callback(beams);
+            llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
+            callback(callback_state, common_beam_prefix);
+            common_prefix.clear();
         }
         beam& top_b = top_beam();
-        top_b.shift_tokens(response, top_b.tokens.size());
+        top_b.shift_tokens(common_prefix, top_b.tokens.size());
+        llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
+        callback(callback_state, common_beam_prefix);
     }
 
     // As beams grow, the cumulative probabilities decrease.
@@ -3111,38 +3115,20 @@ struct beam_search {
     }
 };
 
-// Not thread-safe.
-const char* llama_beam_search(llama_context * ctx, int beam_width,
-                              int n_past, int const n_predict, int const n_threads) {
-    static std::string beam_search_response;
+void llama_beam_search(llama_context * ctx,
+                       llama_beam_search_callback_fn_t callback, void* callback_state,
+                       int beam_width, int n_past, int const n_predict, int const n_threads) {
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
     beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);
 
-    beam_search.loop([&](std::vector<beam>& beams) {
-#if 1 // DEBUG: print current beams for this iteration
-        std::cout << "\n\nCurrent beams:\n";
-        for (size_t j=0 ; j < beams.size() ; ++j) {
-            std::cout << "beams["<<j<<"]: ";
-            out_beam(std::cout, ctx, beams[j]);
-            std::cout << std::endl;
-        }
-#else
-        std::cout << '.' << std::flush; // Show progress
-#endif
-    });
-
-    // Save beam sentence to beam_search_response. Is there a better way?
-    std::ostringstream oss;
-    for (llama_token const token : beam_search.response) {
-        oss << llama_token_to_str(ctx, token);
-    }
-    beam_search_response = oss.str();
+    // callback(callback_state, common_beam_prefix) is called on each iteration, and when
+    // stop condition is met with remaining tokens from beam with the highest probability.
+    beam_search.loop(callback, callback_state);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     ctx->n_sample++;
-    return beam_search_response.c_str();
 }
 
 //
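For readers unfamiliar with the two-stage min-heap idiom described in the comments above, here is a minimal standalone sketch of the same top-k pattern. It is illustrative only; the item type and keep_top_k() helper are hypothetical stand-ins and are not part of this commit.

    #include <algorithm>
    #include <vector>

    struct item { float p; };  // hypothetical stand-in for a beam with probability p

    // Keep the k most probable items seen so far. The greater-than comparator makes the
    // std::*_heap algorithms maintain a min-heap, so the weakest kept item sits at front().
    void keep_top_k(std::vector<item>& heap, item const candidate, size_t const k) {
        auto const comp = [](item const& a, item const& b) { return a.p > b.p; };
        if (heap.size() < k) {
            heap.push_back(candidate);                            // stage 1: gather until full...
            if (heap.size() == k) {
                std::make_heap(heap.begin(), heap.end(), comp);   // ...then heapify once
            }
        } else if (candidate.p > heap.front().p) {
            std::pop_heap(heap.begin(), heap.end(), comp);        // stage 2: least element moves to back()
            heap.back() = candidate;                              // replace it with the new element
            std::push_heap(heap.begin(), heap.end(), comp);       // and push it back into the heap
        }
    }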

llama.h

Lines changed: 7 additions & 1 deletion
@@ -460,13 +460,19 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    struct llama_tokens_view {
+        llama_token const* tokens;
+        size_t size;
+    };
+    typedef void (*llama_beam_search_callback_fn_t)(void* state, llama_tokens_view const);
+
     /// @details Deterministically returns entire sentence constructed by a beam search.
     /// @param ctx Pointer to the llama_context.
     /// @param beam_width The number of parallel beams to use.
     /// @param n_past The number of tokens already evaluated.
     /// @param n_predict The maximum number of tokens to predict.
     /// @param n_threads The maximum number of threads as passed to llama_eval().
-    LLAMA_API const char* llama_beam_search(struct llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads);
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, int beam_width, int n_past, int n_predict, int n_threads);
 
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
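Putting the new declarations together: client code implements a llama_beam_search_callback_fn_t, threads arbitrary state through the void* parameter, and calls llama_beam_search() once the prompt has been evaluated. Below is a condensed sketch of the usage shown in examples/beam_search/beam_search.cpp above; the callback name, beam width, n_predict value, and n_threads variable are illustrative placeholders.

    // Collect every converged token prefix into the vector passed via callback_state.
    void collect_tokens_callback(void* callback_state, llama_tokens_view const tokens_view) {
        auto* out = static_cast<std::vector<llama_token>*>(callback_state);
        out->insert(out->end(), tokens_view.tokens, tokens_view.tokens + tokens_view.size);
    }

    // ... after llama_eval() has consumed the prompt, leaving n_past tokens in the context ...
    std::vector<llama_token> response;
    llama_beam_search(ctx, collect_tokens_callback, &response,
                      /*beam_width=*/2, n_past, /*n_predict=*/256, n_threads);
    for (llama_token const token_id : response) {
        printf("%s", llama_token_to_str(ctx, token_id));  // detokenize the final beam
    }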
