
Commit f20e584

Improve beam_search callback pattern by giving access to beams_state and by returning a beam_search_control struct to control execution.
Parent: ef7850d

File tree: 3 files changed, +100 -49 lines


examples/beam_search/beam_search.cpp

Lines changed: 15 additions & 7 deletions
@@ -28,21 +28,29 @@
 #endif
 
 
+// Function matching type llama_beam_search_callback_fn_t.
 // Custom callback example is called each time the beams lengths increase:
 // * Show progress by printing ',' following by number of convergent beam tokens if any.
-// * When all beams converge to a common prefix, they are made available in tokens_view.
-// This is also called when the stop condition is met, in which case the beam with the
-// highest probability is chosen, and its remaining tokens are available in tokens_view.
-// Collect them into std::vector<llama_token> response which is pointed to by callback_state.
-void beam_search_callback(void* callback_state, llama_tokens_view const tokens_view) {
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+// This is also called when the stop condition is met.
+// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
+beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
     printf(","); // Show progress
-    if (size_t const n = tokens_view.size) {
+    if (size_t const n = beams_state.common_prefix_length) {
         auto* response = static_cast<std::vector<llama_token>*>(callback_state);
         response->resize(response->size() + n);
-        std::copy(tokens_view.tokens, tokens_view.tokens + n, response->end() - n);
+        assert(0u < beams_state.n_beams);
+        std::copy(beams_state.beams[0], beams_state.beams[0] + n, response->end() - n);
         printf("%lu", n);
     }
     fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t j=0 ; j < beams.size() ; ++j) {
+        std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
+    }
+#endif
+    return { beams_state.n_beams, false }; // Continue beam search.
 }
 
 int main(int argc, char ** argv)
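For orientation, here is a rough sketch (not part of this commit) of how the updated callback might be wired up from the example's main(). The argument order of llama_beam_search is assumed from the @param documentation in llama.h below, and the values for beam_width, n_past, n_predict, and n_threads are placeholders.

    std::vector<llama_token> response;   // filled by beam_search_callback above
    int    const n_past     = 0;         // placeholder: tokens already evaluated for the prompt
    int    const n_predict  = 256;       // placeholder: maximum number of tokens to predict
    int    const n_threads  = 4;         // placeholder
    size_t const beam_width = 2;         // number of parallel beams
    llama_beam_search(ctx, beam_search_callback, &response, beam_width, n_past, n_predict, n_threads);
    // response now holds the common-prefix tokens flushed on each callback, plus the tokens of the
    // winning beam delivered on the final (last_call == true) invocation.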

llama.cpp

Lines changed: 63 additions & 38 deletions
@@ -2977,30 +2977,28 @@ struct beam_search {
     // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
     bool common_prefix_evaluated;
 
+    // Memory used by beam_state
+    std::vector<size_t> beam_lengths;
+    std::vector<llama_token const*> beam_ptrs;
+
     beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
         : ctx(ctx)
         , beam_width(beam_width)
         , n_past(n_past)
         , n_predict(n_predict)
-        , n_threads(n_threads) {
+        , n_threads(n_threads)
+        , beam_lengths(beam_width)
+        , beam_ptrs(beam_width) {
         beams.reserve(beam_width);
         next_beams.reserve(beam_width);
     }
 
-    // Find common_prefix_length based on beams.
-    // Requires beams is not empty.
-    size_t find_common_prefix_length() {
-        size_t common_prefix_length = beams[0].tokens.size();
-        for (size_t i=1 ; i<beams.size() ; ++i) {
-            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
-            for (size_t j=0 ; j<common_prefix_length ; ++j) {
-                if (beams[0].tokens[j] != beams[i].tokens[j]) {
-                    common_prefix_length = j;
-                    break;
-                }
-            }
+    // Collapse beams to a single beam given by index.
+    void collapse_beams(size_t const beam_idx) {
+        if (0u < beam_idx) {
+            std::swap(beams[0], beams[beam_idx]);
         }
-        return common_prefix_length;
+        beams.resize(1);
     }
 
     // Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
@@ -3071,49 +3069,78 @@
         }
     }
 
+    // Find common_prefix_length based on beams.
+    // Requires beams is not empty.
+    size_t find_common_prefix_length() {
+        size_t common_prefix_length = beams[0].tokens.size();
+        for (size_t i=1 ; i<beams.size() ; ++i) {
+            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
+            for (size_t j=0 ; j<common_prefix_length ; ++j) {
+                if (beams[0].tokens[j] != beams[i].tokens[j]) {
+                    common_prefix_length = j;
+                    break;
+                }
+            }
+        }
+        return common_prefix_length;
+    }
+
+    // Construct beams_state to send back to caller via the callback function.
+    // Side effect: set common_prefix_length = find_common_prefix_length();
+    beams_state get_beams_state(bool const last_call) {
+        for (size_t i=0 ; i<beams.size() ; ++i) {
+            beam_lengths[i] = beams[i].tokens.size();
+            beam_ptrs[i]    = beams[i].tokens.data();
+        }
+        common_prefix_length = find_common_prefix_length();
+        return {beams.size(), beam_lengths.data(), beam_ptrs.data(), common_prefix_length, last_call};
+    }
+
     // Loop:
-    // * while i < n_predict, OR
-    // * until all of the beams have nreached end-of-sentence, OR
-    // * until the highest probability beam is at end-of-sentence
+    // * while i < n_predict, AND
+    // * any of the beams have not yet reached end-of-sentence, AND
+    // * the highest probability beams (plural in case of ties) are not at end-of-sentence
     //   (since all other beam probabilities can only decrease)
     void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
         beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
-        auto const eos = [](beam const& beam) { return beam.eos(); };
-        for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
-            common_prefix_length = find_common_prefix_length();
-            llama_tokens_view const common_prefix{beams[0].tokens.data(), common_prefix_length};
-            callback(callback_state, common_prefix);
+        auto const not_eos = [](beam const& beam) { return !beam.eos(); };
+        for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
+                       !beams[top_beam_index()].eos() ; ++i) {
+            beam_search_control const control = callback(callback_state, get_beams_state(false));
+            if (control.collapse_to < beams.size()) {
+                // Caller has manually selected a specific beam. Collapse beams into it.
+                collapse_beams(control.collapse_to);
+            }
+            if (control.stop) {
+                break;
+            }
             common_prefix_evaluated = false;
             for (beam& beam : beams) {
                 fill_next_beams_by_top_probabilities(beam);
            }
            beams.swap(next_beams);
            renormalize_beam_probabilities(beams);
            std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
-#if 0 // DEBUG: print current beams for this iteration
-            std::cout << "\n\nCurrent beams:\n";
-            for (size_t j=0 ; j < beams.size() ; ++j) {
-                std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
-            }
-#endif
         }
-        beam& top_b = top_beam();
-        llama_tokens_view const top_beam_tokens{top_b.tokens.data(), top_b.tokens.size()};
-        callback(callback_state, top_beam_tokens);
+        collapse_beams(top_beam_index());
+        callback(callback_state, get_beams_state(true));
     }
 
     // As beams grow, the cumulative probabilities decrease.
     // Renormalize them to avoid floating point underflow.
     static void renormalize_beam_probabilities(std::vector<beam>& beams) {
         auto const sum_p = [](float sum, beam& beam) { return sum + beam.p; };
         float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
-        std::for_each(beams.begin(), beams.end(), [inv_sum](beam& beam) { beam.p *= inv_sum; });
+        std::for_each(beams.begin(), beams.end(), [=](beam& beam) { beam.p *= inv_sum; });
     }
 
-    // Return beam with highest probability.
-    beam& top_beam() {
-        auto const by_p = [](beam const& a, beam const& b) { return a.p < b.p; };
-        return *std::max_element(beams.begin(), beams.end(), by_p);
+    // Return index of highest ranking beam by (probability,eos()).
+    // In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
+    // Assumes beams is non-empty.
+    size_t top_beam_index() {
+        auto const by_p_and_eos = [](beam const& a, beam const& b) {
+            return a.p < b.p || (a.p == b.p && a.eos() < b.eos()); };
+        return std::max_element(beams.begin(), beams.end(), by_p_and_eos) - beams.begin();
     }
 };
 
@@ -3125,8 +3152,6 @@ void llama_beam_search(llama_context * ctx,
 
     beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);
 
-    // callback(callback_state, common_beam_prefix) is called on each iteration, and when
-    // stop condition is met with remaining tokens from beam with the highest probability.
     beam_search.loop(callback, callback_state);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
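To make the prefix-reporting contract concrete, here is a small self-contained sketch (not from the commit) of the same common-prefix computation that get_beams_state() now surfaces through beams_state.common_prefix_length. The token alias is a stand-in for llama_token.

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    using token = int; // stand-in for llama_token in this sketch

    // Mirrors find_common_prefix_length(): length of the longest prefix shared by all beams.
    static size_t common_prefix_length(std::vector<std::vector<token>> const & beams) {
        size_t len = beams[0].size();
        for (size_t i = 1; i < beams.size(); ++i) {
            len = std::min(len, beams[i].size());
            for (size_t j = 0; j < len; ++j) {
                if (beams[0][j] != beams[i][j]) {
                    len = j;
                    break;
                }
            }
        }
        return len;
    }

    int main() {
        // All three beams agree on their first three tokens, so a callback may flush {7, 8, 9}.
        assert(common_prefix_length({{7, 8, 9, 4}, {7, 8, 9}, {7, 8, 9, 5, 6}}) == 3);
        return 0;
    }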

llama.h

Lines changed: 22 additions & 4 deletions
@@ -460,14 +460,32 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
-    struct llama_tokens_view {
-        llama_token const* tokens;
-        size_t size;
+    // Passed to beam_search_callback function.
+    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+    // These pointers are valid only during the synchronous callback, so should not be saved.
+    struct beams_state {
+        size_t n_beams;                   // Number of elements in beam_lengths[] and beams[].
+        size_t const* beam_lengths;       // Length of each beam.
+        llama_token const* const* beams;  // Current tokens in each beam.
+        size_t common_prefix_length;      // Current max length of prefix tokens shared by all beams.
+        bool last_call;                   // True iff this is the last callback invocation.
+    };
+    // Must be returned by beam_search_callback function.
+    struct beam_search_control {
+        size_t collapse_to;  // Collapse to a beam index. Ignored if n_beams <= collapse_to.
+        bool stop;           // Stop beam search. Set to false to continue.
     };
-    typedef void (*llama_beam_search_callback_fn_t)(void* state, llama_tokens_view const);
+    // Type of pointer to the beam_search_callback function.
+    // void* callback_state is any custom data passed to llama_beam_search, that is subsequently
+    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+    typedef beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, beams_state);
 
     /// @details Deterministically returns entire sentence constructed by a beam search.
     /// @param ctx Pointer to the llama_context.
+    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+    ///        The return beam_search_control can be used to control the beam_search execution.
+    /// @param callback_state A pointer that is passed back to callback and nothing more.
     /// @param beam_width The number of parallel beams to use.
     /// @param n_past The number of tokens already evaluated.
     /// @param n_predict The maximum number of tokens to predict.