Skip to content

Commit e156b30

Browse files
committed
Improve the beam_search callback pattern by giving the callback access to beams_state, and by having it return a beam_search_control struct to control execution.
1 parent 117088a commit e156b30

File tree

3 files changed

+100
-49
lines changed

3 files changed

+100
-49
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,29 @@
2828
#endif
2929

3030

31+
// Function matching type llama_beam_search_callback_fn_t.
3132
// Custom callback example is called each time the beams lengths increase:
3233
// * Show progress by printing ',' following by number of convergent beam tokens if any.
33-
// * When all beams converge to a common prefix, they are made available in tokens_view.
34-
// This is also called when the stop condition is met, in which case the beam with the
35-
// highest probability is chosen, and its remaining tokens are available in tokens_view.
36-
// Collect them into std::vector<llama_token> response which is pointed to by callback_state.
37-
void beam_search_callback(void* callback_state, llama_tokens_view const tokens_view) {
34+
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
35+
// This is also called when the stop condition is met.
36+
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
37+
beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
3838
printf(","); // Show progress
39-
if (size_t const n = tokens_view.size) {
39+
if (size_t const n = beams_state.common_prefix_length) {
4040
auto* response = static_cast<std::vector<llama_token>*>(callback_state);
4141
response->resize(response->size() + n);
42-
std::copy(tokens_view.tokens, tokens_view.tokens + n, response->end() - n);
42+
assert(0u < beams_state.n_beams);
43+
std::copy(beams_state.beams[0], beams_state.beams[0] + n, response->end() - n);
4344
printf("%lu", n);
4445
}
4546
fflush(stdout);
47+
#if 0 // DEBUG: print current beams for this iteration
48+
std::cout << "\n\nCurrent beams:\n";
49+
for (size_t j=0 ; j < beams.size() ; ++j) {
50+
std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
51+
}
52+
#endif
53+
return { beams_state.n_beams, false }; // Continue beam search.
4654
}
4755

4856
int main(int argc, char ** argv)

llama.cpp

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2962,30 +2962,28 @@ struct beam_search {
29622962
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29632963
bool common_prefix_evaluated;
29642964

2965+
// Memory used by beam_state
2966+
std::vector<size_t> beam_lengths;
2967+
std::vector<llama_token const*> beam_ptrs;
2968+
29652969
beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
29662970
: ctx(ctx)
29672971
, beam_width(beam_width)
29682972
, n_past(n_past)
29692973
, n_predict(n_predict)
2970-
, n_threads(n_threads) {
2974+
, n_threads(n_threads)
2975+
, beam_lengths(beam_width)
2976+
, beam_ptrs(beam_width) {
29712977
beams.reserve(beam_width);
29722978
next_beams.reserve(beam_width);
29732979
}
29742980

2975-
// Find common_prefix_length based on beams.
2976-
// Requires beams is not empty.
2977-
size_t find_common_prefix_length() {
2978-
size_t common_prefix_length = beams[0].tokens.size();
2979-
for (size_t i=1 ; i<beams.size() ; ++i) {
2980-
common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
2981-
for (size_t j=0 ; j<common_prefix_length ; ++j) {
2982-
if (beams[0].tokens[j] != beams[i].tokens[j]) {
2983-
common_prefix_length = j;
2984-
break;
2985-
}
2986-
}
2981+
// Collapse beams to a single beam given by index.
2982+
void collapse_beams(size_t const beam_idx) {
2983+
if (0u < beam_idx) {
2984+
std::swap(beams[0], beams[beam_idx]);
29872985
}
2988-
return common_prefix_length;
2986+
beams.resize(1);
29892987
}
29902988

29912989
// Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
@@ -3056,49 +3054,78 @@ struct beam_search {
30563054
}
30573055
}
30583056

3057+
// Find common_prefix_length based on beams.
3058+
// Requires beams is not empty.
3059+
size_t find_common_prefix_length() {
3060+
size_t common_prefix_length = beams[0].tokens.size();
3061+
for (size_t i=1 ; i<beams.size() ; ++i) {
3062+
common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
3063+
for (size_t j=0 ; j<common_prefix_length ; ++j) {
3064+
if (beams[0].tokens[j] != beams[i].tokens[j]) {
3065+
common_prefix_length = j;
3066+
break;
3067+
}
3068+
}
3069+
}
3070+
return common_prefix_length;
3071+
}
3072+
3073+
// Construct beams_state to send back to caller via the callback function.
3074+
// Side effect: set common_prefix_length = find_common_prefix_length();
3075+
beams_state get_beams_state(bool const last_call) {
3076+
for (size_t i=0 ; i<beams.size() ; ++i) {
3077+
beam_lengths[i] = beams[i].tokens.size();
3078+
beam_ptrs[i] = beams[i].tokens.data();
3079+
}
3080+
common_prefix_length = find_common_prefix_length();
3081+
return {beams.size(), beam_lengths.data(), beam_ptrs.data(), common_prefix_length, last_call};
3082+
}
3083+
30593084
// Loop:
3060-
// * while i < n_predict, OR
3061-
// * until all of the beams have nreached end-of-sentence, OR
3062-
// * until the highest probability beam is at end-of-sentence
3085+
// * while i < n_predict, AND
3086+
// * any of the beams have not yet reached end-of-sentence, AND
3087+
// * the highest probability beams (plural in case of ties) are not at end-of-sentence
30633088
// (since all other beam probabilities can only decrease)
30643089
void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
30653090
beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
3066-
auto const eos = [](beam const& beam) { return beam.eos(); };
3067-
for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
3068-
common_prefix_length = find_common_prefix_length();
3069-
llama_tokens_view const common_prefix{beams[0].tokens.data(), common_prefix_length};
3070-
callback(callback_state, common_prefix);
3091+
auto const not_eos = [](beam const& beam) { return !beam.eos(); };
3092+
for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
3093+
!beams[top_beam_index()].eos() ; ++i) {
3094+
beam_search_control const control = callback(callback_state, get_beams_state(false));
3095+
if (control.collapse_to < beams.size()) {
3096+
// Caller has manually selected a specific beam. Collapse beams into it.
3097+
collapse_beams(control.collapse_to);
3098+
}
3099+
if (control.stop) {
3100+
break;
3101+
}
30713102
common_prefix_evaluated = false;
30723103
for (beam& beam : beams) {
30733104
fill_next_beams_by_top_probabilities(beam);
30743105
}
30753106
beams.swap(next_beams);
30763107
renormalize_beam_probabilities(beams);
30773108
std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
3078-
#if 0 // DEBUG: print current beams for this iteration
3079-
std::cout << "\n\nCurrent beams:\n";
3080-
for (size_t j=0 ; j < beams.size() ; ++j) {
3081-
std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
3082-
}
3083-
#endif
30843109
}
3085-
beam& top_b = top_beam();
3086-
llama_tokens_view const top_beam_tokens{top_b.tokens.data(), top_b.tokens.size()};
3087-
callback(callback_state, top_beam_tokens);
3110+
collapse_beams(top_beam_index());
3111+
callback(callback_state, get_beams_state(true));
30883112
}
30893113

30903114
// As beams grow, the cumulative probabilities decrease.
30913115
// Renormalize them to avoid floating point underflow.
30923116
static void renormalize_beam_probabilities(std::vector<beam>& beams) {
30933117
auto const sum_p = [](float sum, beam& beam) { return sum + beam.p; };
30943118
float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
3095-
std::for_each(beams.begin(), beams.end(), [inv_sum](beam& beam) { beam.p *= inv_sum; });
3119+
std::for_each(beams.begin(), beams.end(), [=](beam& beam) { beam.p *= inv_sum; });
30963120
}
30973121

3098-
// Return beam with highest probability.
3099-
beam& top_beam() {
3100-
auto const by_p = [](beam const& a, beam const& b) { return a.p < b.p; };
3101-
return *std::max_element(beams.begin(), beams.end(), by_p);
3122+
// Return index of highest ranking beam by (probability,eos()).
3123+
// In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
3124+
// Assumes beams is non-empty.
3125+
size_t top_beam_index() {
3126+
auto const by_p_and_eos = [](beam const& a, beam const& b) {
3127+
return a.p < b.p || (a.p == b.p && a.eos() < b.eos()); };
3128+
return std::max_element(beams.begin(), beams.end(), by_p_and_eos) - beams.begin();
31023129
}
31033130
};
31043131

@@ -3110,8 +3137,6 @@ void llama_beam_search(llama_context * ctx,
31103137

31113138
beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);
31123139

3113-
// callback(callback_state, common_beam_prefix) is called on each iteration, and when
3114-
// stop condition is met with remaining tokens from beam with the highest probability.
31153140
beam_search.loop(callback, callback_state);
31163141

31173142
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

llama.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,14 +443,32 @@ extern "C" {
443443
/// @details Accepts the sampled token into the grammar
444444
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
445445

446-
struct llama_tokens_view {
447-
llama_token const* tokens;
448-
size_t size;
446+
// Passed to beam_search_callback function.
447+
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
448+
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
449+
// These pointers are valid only during the synchronous callback, so should not be saved.
450+
struct beams_state {
451+
size_t n_beams; // Number of elements in beam_lengths[] and beams[].
452+
size_t const* beam_lengths; // Length of each beam.
453+
llama_token const* const* beams; // Current tokens in each beam.
454+
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
455+
bool last_call; // True iff this is the last callback invocation.
456+
};
457+
// Must be returned by beam_search_callback function.
458+
struct beam_search_control {
459+
size_t collapse_to; // Collapse to a beam index. Ignored if n_beams <= collapse_to.
460+
bool stop; // Stop beam search. Set to false to continue.
449461
};
450-
typedef void (*llama_beam_search_callback_fn_t)(void* state, llama_tokens_view const);
462+
// Type of pointer to the beam_search_callback function.
463+
// void* callback_state is any custom data passed to llama_beam_search, that is subsequently
464+
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
465+
typedef beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, beams_state);
451466

452467
/// @details Deterministically returns entire sentence constructed by a beam search.
453468
/// @param ctx Pointer to the llama_context.
469+
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
470+
/// The return beam_search_control can be used to control the beam_search execution.
471+
/// @param callback_state A pointer that is passed back to callback and nothing more.
454472
/// @param beam_width The number of parallel beams to use.
455473
/// @param n_past The number of tokens already evaluated.
456474
/// @param n_predict The maximum number of tokens to predict.

0 commit comments

Comments
 (0)