Skip to content

Commit 7470edd

Browse files
committed
Drop struct llama_beam_search_control. Instead, callback sends/receives data to/from beam_search via llama_beams_state. EOS determination is now responsibility of callback.
1 parent a5a220b commit 7470edd

File tree

3 files changed

+47
-42
lines changed

3 files changed

+47
-42
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct ostream_beam_view {
3333
llama_beam_view beam_view;
3434
};
3535
std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
36-
os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos() << ") tokens(";
36+
os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
3737
for (size_t i=0 ; i<obv.beam_view.n_tokens ; ++i) {
3838
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
3939
}
@@ -46,14 +46,25 @@ struct beam_search_callback_state {
4646
std::vector<llama_token>* response;
4747
};
4848

49+
// Return true iff the token sequence ends with the end-of-sentence token.
// The unnamed callback-state parameter is unused; it exists only to keep the
// signature uniform with other callback helpers.
bool is_at_eos(beam_search_callback_state, llama_token const* tokens, size_t const n_tokens) {
    if (n_tokens == 0) {
        return false;
    }
    return tokens[n_tokens - 1] == llama_token_eos();
}
52+
4953
// Function matching type llama_beam_search_callback_fn_t.
5054
// Custom callback example is called each time the beams lengths increase:
5155
// * Show progress by printing ',' following by number of convergent beam tokens if any.
5256
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
5357
// This is also called when the stop condition is met.
5458
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
55-
llama_beam_search_control beam_search_callback(void* callback_state, llama_beams_state const beams_state) {
59+
void beam_search_callback(void* callback_state, llama_beams_state beams_state) {
5660
auto const state = *static_cast<beam_search_callback_state*>(callback_state);
61+
// Mark beams as EOS as needed.
62+
for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
63+
llama_beam_view& beam_view = beams_state.beam_views[i];
64+
if (!beam_view.eos && is_at_eos(state, beam_view.tokens, beam_view.n_tokens)) {
65+
beam_view.eos = true;
66+
}
67+
}
5768
printf(","); // Show progress
5869
if (size_t const n = beams_state.common_prefix_length) {
5970
state.response->resize(state.response->size() + n);
@@ -69,10 +80,6 @@ llama_beam_search_control beam_search_callback(void* callback_state, llama_beams
6980
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
7081
}
7182
#endif
72-
return llama_beam_search_control{
73-
beams_state.n_beams, // = collapse_to. Any index out of range means do not collapse beams.
74-
false // = stop. Don't stop beam search.
75-
};
7683
}
7784

7885
int main(int argc, char ** argv)

llama.cpp

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include <algorithm>
4848
#include <initializer_list>
4949
#include <thread>
50+
#include <tuple>
5051
#include <atomic>
5152
#include <mutex>
5253
#include <sstream>
@@ -2893,13 +2894,17 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
28932894
struct llama_beam {
28942895
std::vector<llama_token> tokens;
28952896
float p; // Cumulative beam probability (renormalized relative to all beams)
2896-
// end-of-sentence
2897-
bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
2897+
bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
2898+
// Sort beams by probability. In case of ties, prefer beams at eos.
2899+
bool operator<(llama_beam const& rhs) const {
2900+
return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
2901+
}
28982902
// Shift off first n tokens and discard them.
28992903
void shift_tokens(size_t const n) {
29002904
std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
29012905
tokens.resize(tokens.size() - n);
29022906
}
2907+
llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
29032908
};
29042909

29052910
// A struct for calculating logit-related info.
@@ -2961,7 +2966,7 @@ struct beam_search {
29612966
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29622967
bool common_prefix_evaluated;
29632968

2964-
// Temporary memory used by llama_beams_state to pass back via callback.
2969+
// Used to communicate to/from callback on beams state.
29652970
std::vector<llama_beam_view> beam_views;
29662971

29672972
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
@@ -2996,7 +3001,7 @@ struct beam_search {
29963001
// with the common token prefix, so shift it off this beam.
29973002
beam.shift_tokens(common_prefix_length);
29983003
}
2999-
if (beam.eos()) {
3004+
if (beam.eos) {
30003005
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
30013006
if (next_beams.size() < beam_width) {
30023007
next_beams.push_back(std::move(beam));
@@ -3071,7 +3076,7 @@ struct beam_search {
30713076
// Side effect: set common_prefix_length = find_common_prefix_length();
30723077
llama_beams_state get_beams_state(bool const last_call) {
30733078
for (size_t i=0 ; i<beams.size() ; ++i) {
3074-
beam_views[i] = llama_beam_view{beams[i].tokens.data(), beams[i].tokens.size(), beams[i].p};
3079+
beam_views[i] = beams[i].view();
30753080
}
30763081
common_prefix_length = find_common_prefix_length();
30773082
return {beam_views.data(), beams.size(), common_prefix_length, last_call};
@@ -3080,28 +3085,24 @@ struct beam_search {
30803085
// Loop:
30813086
// * while i < n_predict, AND
30823087
// * any of the beams have not yet reached end-of-sentence, AND
3083-
// * the highest probability beams (plural in case of ties) are not at end-of-sentence
3088+
// * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
30843089
// (since all other beam probabilities can only decrease)
30853090
void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
3086-
beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
3087-
auto const not_eos = [](llama_beam const& beam) { return !beam.eos(); };
3091+
beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eos.
3092+
auto const not_eos = [](llama_beam const& beam) { return !beam.eos; };
30883093
for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
3089-
!beams[top_beam_index()].eos() ; ++i) {
3090-
llama_beam_search_control const control = callback(callback_state, get_beams_state(false));
3091-
if (control.collapse_to < beams.size()) {
3092-
// Caller has manually selected a specific beam. Collapse beams into it.
3093-
collapse_beams(control.collapse_to);
3094-
}
3095-
if (control.stop) {
3096-
break;
3097-
}
3098-
common_prefix_evaluated = false;
3094+
!beams[top_beam_index()].eos ; ++i) {
3095+
callback(callback_state, get_beams_state(false));
3096+
update_beams_from_beam_views(); // Update values (p,eos) that callback may have changed.
3097+
common_prefix_evaluated = false; // Any common prefix has not yet been llama_eval()ed.
3098+
// Zero-out next_beam probabilities to place them last in following min-heap.
3099+
std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
30993100
for (llama_beam& beam : beams) {
31003101
fill_next_beams_by_top_probabilities(beam);
31013102
}
3103+
// next_beams become the beams of next/final iteration. Swap them to re-use memory.
31023104
beams.swap(next_beams);
31033105
renormalize_beam_probabilities(beams);
3104-
std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
31053106
}
31063107
collapse_beams(top_beam_index());
31073108
callback(callback_state, get_beams_state(true));
@@ -3115,13 +3116,17 @@ struct beam_search {
31153116
std::for_each(beams.begin(), beams.end(), [=](llama_beam& beam) { beam.p *= inv_sum; });
31163117
}
31173118

3118-
// Return index of highest ranking beam by (probability,eos()).
3119-
// In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
3120-
// Assumes beams is non-empty.
3119+
// Assumes beams is non-empty. Uses llama_beam::operator<() for ordering.
31213120
size_t top_beam_index() {
3122-
auto const by_p_and_eos = [](llama_beam const& a, llama_beam const& b) {
3123-
return a.p < b.p || (a.p == b.p && a.eos() < b.eos()); };
3124-
return std::max_element(beams.begin(), beams.end(), by_p_and_eos) - beams.begin();
3121+
return std::max_element(beams.begin(), beams.end()) - beams.begin();
3122+
}
3123+
3124+
// Copy (p,eos) for each beam which may have been changed by the callback.
3125+
void update_beams_from_beam_views() {
3126+
for (size_t i=0 ; i<beams.size() ; ++i) {
3127+
beams[i].p = beam_views[i].p;
3128+
beams[i].eos = beam_views[i].eos;
3129+
}
31253130
}
31263131
};
31273132

llama.h

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -460,34 +460,27 @@ extern "C" {
460460
/// @details Accepts the sampled token into the grammar
461461
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
462462

463-
// Lightweight view of a beam
464463
struct llama_beam_view {
465464
llama_token const* tokens;
466465
size_t n_tokens;
467-
float p; // Cumulative beam probability (renormalized relative to all beams)
468-
// end-of-sentence
469-
bool eos() const { return n_tokens && tokens[n_tokens-1u] == llama_token_eos(); }
466+
float p; // Cumulative beam probability (renormalized relative to all beams)
467+
bool eos; // Callback should set this to true when a beam is at end-of-sentence.
470468
};
471469

472470
// Passed to beam_search_callback function.
473471
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
474472
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
475473
// These pointers are valid only during the synchronous callback, so should not be saved.
476474
struct llama_beams_state {
477-
llama_beam_view* beam_views; // View of each beam.
475+
llama_beam_view* beam_views;
478476
size_t n_beams; // Number of elements in beam_views[].
479477
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
480478
bool last_call; // True iff this is the last callback invocation.
481479
};
482-
// Must be returned by beam_search_callback function.
483-
struct llama_beam_search_control {
484-
size_t collapse_to; // Collapse to a beam index. Ignored if n_beams <= collapse_to.
485-
bool stop; // Stop beam search. Set to false to continue.
486-
};
487480
// Type of pointer to the beam_search_callback function.
488481
// void* callback_state is any custom data passed to llama_beam_search, that is subsequently
489482
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
490-
typedef llama_beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, llama_beams_state);
483+
typedef void (*llama_beam_search_callback_fn_t)(void* callback_state, llama_beams_state);
491484

492485
/// @details Deterministically returns entire sentence constructed by a beam search.
493486
/// @param ctx Pointer to the llama_context.

0 commit comments

Comments
 (0)