
Commit 29ec1a0

Add llama_beam_search_callback_fn_t and improve comments.
1 parent b1dbc44 commit 29ec1a0

3 files changed: +55 additions, -41 deletions

examples/beam_search/beam_search.cpp

Lines changed: 25 additions & 3 deletions
@@ -28,6 +28,22 @@
 #endif
 
 
+// Custom callback example, called each time the beam lengths increase:
+// * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+// * When all beams converge to a common prefix, those tokens are made available in tokens_view.
+// This is also called when the stop condition is met, in which case the beam with the
+// highest probability is chosen, and its remaining tokens are available in tokens_view.
+// Collect them into the std::vector<llama_token> pointed to by callback_state.
+void beam_search_callback(void* callback_state, llama_tokens_view const tokens_view) {
+    printf(",");  // Show progress
+    if (size_t const n = tokens_view.size) {
+        auto* response = static_cast<std::vector<llama_token>*>(callback_state);
+        response->resize(response->size() + n);
+        std::copy(tokens_view.tokens, tokens_view.tokens + n, response->end() - n);
+        printf("%zu", n);  // %zu is the portable printf format for size_t
+    }
+    fflush(stdout);
+}
 
 int main(int argc, char ** argv)
 {
@@ -115,9 +131,15 @@ int main(int argc, char ** argv)
     }
     n_past += tokens_list.size();
 
-    int const n_predict = 1024;
-    char const* response = llama_beam_search(ctx, params.n_beams, n_past, n_predict, params.n_threads);
-    printf("\nDone:\n\n%s%s\n", params.prompt.c_str(), response);
+    std::vector<llama_token> response;
+    int const n_predict = 256;
+    llama_beam_search(ctx, beam_search_callback, &response, params.n_beams, n_past, n_predict, params.n_threads);
+
+    printf("\n\n");
+    for (llama_token const token_id : response) {
+        printf("%s", llama_token_to_str(ctx, token_id));
+    }
+    printf("\n");
 #else
     //---------------------------------
     // Main prediction loop :
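
The example collects tokens into a vector; since the callback fully owns interpretation of each converged prefix, a caller could just as easily stream text as it arrives. Below is a minimal sketch of such a variant (hypothetical, not part of this commit), using only the llama_tokens_view struct and llama_token_to_str() visible elsewhere in this diff:

    // Hypothetical streaming variant: detokenize and print each converged
    // prefix immediately instead of collecting tokens into a vector.
    struct stream_state {
        llama_context * ctx;  // needed by llama_token_to_str() for detokenization
    };

    void beam_search_stream_callback(void* callback_state, llama_tokens_view const tokens_view) {
        auto* state = static_cast<stream_state*>(callback_state);
        for (size_t i = 0; i < tokens_view.size; ++i) {
            printf("%s", llama_token_to_str(state->ctx, tokens_view.tokens[i]));
        }
        fflush(stdout);  // make partial output visible immediately
    }

Passed as llama_beam_search(ctx, beam_search_stream_callback, &state, ...), this prints the response incrementally rather than after the search finishes.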

llama.cpp

Lines changed: 23 additions & 37 deletions
@@ -2962,8 +2962,8 @@ struct beam_search {
     int common_prefix_length;
     // true iff llama_eval() has been called with common prefix in current loop iteration.
     bool common_prefix_evaluated;
-    // Save token prefix common to all beams here
-    std::vector<llama_token> response;
+    // Save token prefix common to all beams. Cleared after each loop iteration.
+    std::vector<llama_token> common_prefix;
 
     beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
         : ctx(ctx)
@@ -2990,11 +2990,11 @@ struct beam_search {
         return common_prefix_length;
     }
 
-    // Min-heaps are used to efficiently gather the top-k elements (k=beam_width).
+    // Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
     // The repetitive patterns below reflect the 2 stages of heaps:
     // * Gather elements until the vector is full, then call std::make_heap() on it.
-    // * If the heap is full and a new element is found that should be included,
-    //   pop off the least element, replace it with the new, then push it into the heap.
+    // * If the heap is full and a new element is found that should be included, pop the
+    //   least element to the back(), replace it with the new, then push it into the heap.
     void fill_next_beams_by_top_probabilities(beam& b) {
         // Min-heaps use a greater-than comparator.
         auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
@@ -3006,7 +3006,7 @@ struct beam_search {
         if (b.eos()) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < static_cast<size_t>(beam_width)) {
-                next_beams.push_back(b);
+                next_beams.push_back(std::move(b));
                 if (next_beams.size() == static_cast<size_t>(beam_width)) {
                     std::make_heap(next_beams.begin(), next_beams.end(), comp);
                 }
@@ -3020,7 +3020,7 @@ struct beam_search {
         if (!b.tokens.empty()) {
             llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
             if (!common_prefix_evaluated && common_prefix_length) {
-                b.shift_tokens(response, common_prefix_length);
+                b.shift_tokens(common_prefix, common_prefix_length);
                 n_past += common_prefix_length;
                 common_prefix_evaluated = true;
             }
@@ -3059,12 +3059,12 @@ struct beam_search {
     }
 
     // Loop:
-    // * while i < n_predict
-    // * until all of the beams have nreached end-of-sentence
+    // * while i < n_predict, OR
+    // * until all of the beams have reached end-of-sentence, OR
     // * until the highest probability beam is at end-of-sentence
     //   (since all other beam probabilities can only decrease)
-    void loop(std::function<void(std::vector<beam>&)> const callback) {
-        beams.push_back({{}, 1.0f});
+    void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
+        beams.push_back({{}, 1.0f});  // Start with one empty beam w/ probability = 1.0.
         auto const eos = [](beam const& beam) { return beam.eos(); };
         for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
             common_prefix_evaluated = false;
@@ -3075,10 +3075,14 @@ struct beam_search {
             beams.swap(next_beams);
             renormalize_beam_probabilities(beams);
             std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
-            callback(beams);
+            llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
+            callback(callback_state, common_beam_prefix);
+            common_prefix.clear();
         }
         beam& top_b = top_beam();
-        top_b.shift_tokens(response, top_b.tokens.size());
+        top_b.shift_tokens(common_prefix, top_b.tokens.size());
+        llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
+        callback(callback_state, common_beam_prefix);
     }
 
     // As beams grow, the cumulative probabilities decrease.
@@ -3096,38 +3100,20 @@
     }
 };
 
-// Not thread-safe.
-const char* llama_beam_search(llama_context * ctx, int beam_width,
-                              int n_past, int const n_predict, int const n_threads) {
-    static std::string beam_search_response;
+void llama_beam_search(llama_context * ctx,
+                       llama_beam_search_callback_fn_t callback, void* callback_state,
+                       int beam_width, int n_past, int const n_predict, int const n_threads) {
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
     beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);
 
-    beam_search.loop([&](std::vector<beam>& beams) {
-#if 1 // DEBUG: print current beams for this iteration
-        std::cout << "\n\nCurrent beams:\n";
-        for (size_t j=0 ; j < beams.size() ; ++j) {
-            std::cout << "beams["<<j<<"]: ";
-            out_beam(std::cout, ctx, beams[j]);
-            std::cout << std::endl;
-        }
-#else
-        std::cout << '.' << std::flush;  // Show progress
-#endif
-    });
-
-    // Save beam sentence to beam_search_response. Is there a better way?
-    std::ostringstream oss;
-    for (llama_token const token : beam_search.response) {
-        oss << llama_token_to_str(ctx, token);
-    }
-    beam_search_response = oss.str();
+    // callback(callback_state, common_beam_prefix) is called on each iteration, and when the
+    // stop condition is met with the remaining tokens from the beam with the highest probability.
+    beam_search.loop(callback, callback_state);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     ctx->n_sample++;
-    return beam_search_response.c_str();
 }
 
 //
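
The two-stage min-heap pattern documented above in fill_next_beams_by_top_probabilities() is easier to see in isolation. Here is a simplified sketch (illustration only, not from this commit) that keeps the k largest floats in place of beams:

    #include <algorithm>
    #include <vector>

    // Keep the k largest values seen so far. With a greater-than comparator,
    // heap.front() is the smallest kept value (a min-heap).
    void top_k_insert(std::vector<float> & heap, size_t k, float candidate) {
        auto const comp = [](float a, float b) { return a > b; };
        if (heap.size() < k) {
            heap.push_back(candidate);  // Stage 1: fill the vector,
            if (heap.size() == k) {
                std::make_heap(heap.begin(), heap.end(), comp);  // then heapify once full.
            }
        } else if (candidate > heap.front()) {  // candidate beats the current minimum
            std::pop_heap(heap.begin(), heap.end(), comp);   // Stage 2: move the least element to back(),
            heap.back() = candidate;                         // replace it with the new element,
            std::push_heap(heap.begin(), heap.end(), comp);  // then push it back into the heap.
        }
    }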

llama.h

Lines changed: 7 additions & 1 deletion
@@ -443,13 +443,19 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    struct llama_tokens_view {
+        llama_token const* tokens;
+        size_t size;
+    };
+    typedef void (*llama_beam_search_callback_fn_t)(void* state, llama_tokens_view const);
+
     /// @details Deterministically returns entire sentence constructed by a beam search.
     /// @param ctx Pointer to the llama_context.
     /// @param beam_width The number of parallel beams to use.
     /// @param n_past The number of tokens already evaluated.
     /// @param n_predict The maximum number of tokens to predict.
     /// @param n_threads The maximum number of threads as passed to llama_eval().
-    LLAMA_API const char* llama_beam_search(struct llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads);
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, int beam_width, int n_past, int n_predict, int n_threads);
 
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
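
Putting the new declarations together, a condensed caller might look like the sketch below (assuming ctx, n_beams, n_past, and n_threads are set up as in the example program above; note that a capture-less lambda converts implicitly to the C function pointer type llama_beam_search_callback_fn_t):

    #include <vector>
    #include "llama.h"

    std::vector<llama_token> run_beam_search(llama_context * ctx, int n_beams,
                                             int n_past, int n_threads) {
        std::vector<llama_token> response;
        auto collect = [](void* state, llama_tokens_view const view) {
            // Append each converged prefix (and the final remainder) in order.
            auto* out = static_cast<std::vector<llama_token>*>(state);
            out->insert(out->end(), view.tokens, view.tokens + view.size);
        };
        llama_beam_search(ctx, collect, &response, n_beams, n_past, /*n_predict=*/256, n_threads);
        return response;
    }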
