
Commit 5522925

Improve beam_state by adding+using struct beam_view.
1 parent 10c8fd2 commit 5522925

File tree

3 files changed: 53 additions & 39 deletions

examples/beam_search/beam_search.cpp

Lines changed: 34 additions & 10 deletions
@@ -27,6 +27,24 @@
 #include <signal.h>
 #endif
 
+// Used for debugging to print out beam tokens.
+struct ostream_beam_view {
+    llama_context* ctx;
+    beam_view bv;
+};
+std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
+    os << "p(" << obv.bv.p << ") eos(" << std::boolalpha << obv.bv.eos() << ") tokens(";
+    for (size_t i=0 ; i<obv.bv.n_tokens ; ++i) {
+        os << llama_token_to_str(obv.ctx, obv.bv.tokens[i]);
+    }
+    return os << ')';
+}
+
+// Put here anything you want back in beam_search_callback().
+struct beam_search_callback_state {
+    llama_context* ctx;
+    std::vector<llama_token>* response;
+};
 
 // Function matching type llama_beam_search_callback_fn_t.
 // Custom callback example is called each time the beams lengths increase:
@@ -35,22 +53,27 @@
 // This is also called when the stop condition is met.
 // Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
 beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
+    auto const state = *static_cast<beam_search_callback_state*>(callback_state);
     printf(","); // Show progress
     if (size_t const n = beams_state.common_prefix_length) {
-        auto* response = static_cast<std::vector<llama_token>*>(callback_state);
-        response->resize(response->size() + n);
+        state.response->resize(state.response->size() + n);
         assert(0u < beams_state.n_beams);
-        std::copy(beams_state.beams[0], beams_state.beams[0] + n, response->end() - n);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        std::copy(tokens, tokens + n, state.response->end() - n);
         printf("%lu", n);
     }
     fflush(stdout);
-#if 0 // DEBUG: print current beams for this iteration
-    std::cout << "\n\nCurrent beams:\n";
-    for (size_t j=0 ; j < beams.size() ; ++j) {
-        std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
-    }
+#if 1 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
+    }
 #endif
-    return { beams_state.n_beams, false }; // Continue beam search.
+    beam_search_control control {
+        beams_state.n_beams, // = collapse_to. Any index out of range means do not collapse beams.
+        false                // = stop. Don't stop beam search.
+    };
+    return control;
 }
 
 int main(int argc, char ** argv)
@@ -140,9 +163,10 @@ int main(int argc, char ** argv)
     n_past += tokens_list.size();
 
     std::vector<llama_token> response;
+    beam_search_callback_state callback_state{ctx, &response};
    size_t const beam_width = static_cast<size_t>(params.n_beams);
    int const n_predict = 256;
-    llama_beam_search(ctx, beam_search_callback, &response, beam_width, n_past, n_predict, params.n_threads);
+    llama_beam_search(ctx, beam_search_callback, &callback_state, beam_width, n_past, n_predict, params.n_threads);
 
    printf("\n\n");
    for (llama_token const token_id : response) {
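
Note: the example callback above always returns "keep going" (collapse_to is set out of range and stop is false). As a sketch of how the control struct could be used differently, here is a hypothetical variant that collapses to the first beam whose last token is end-of-sentence and stops the search. This is a minimal sketch only: the member names collapse_to and stop are assumed from the inline comments in the diff above, and the common-prefix collection done by the real example callback is omitted.

    // Hypothetical variant callback (sketch, not part of this commit).
    beam_search_control stop_on_eos_callback(void * callback_state, beams_state const beams_state) {
        (void) callback_state; // unused in this sketch
        for (size_t i = 0; i < beams_state.n_beams; ++i) {
            if (beams_state.beam_views[i].eos()) {
                return { i, true };            // collapse to beam i and stop (member names assumed)
            }
        }
        return { beams_state.n_beams, false }; // out-of-range index: no collapse; keep searching
    }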

llama.cpp

Lines changed: 6 additions & 24 deletions
@@ -53,8 +53,6 @@
 #include <sstream>
 #include <numeric>
 
-#include <iostream>
-
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -2895,7 +2893,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 
 struct beam {
     std::vector<llama_token> tokens;
-    float p; // Cumulative beam probability (renormalized with each token)
+    float p; // Cumulative beam probability (renormalized relative to all beams)
     // end-of-sentence
     bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
     // Shift off first n tokens and discard them.
@@ -2905,19 +2903,6 @@ struct beam {
     }
 };
 
-// Used for debugging to print out beam tokens.
-struct ostream_beam {
-    llama_context* ctx;
-    beam& b;
-};
-std::ostream& operator<<(std::ostream& os, ostream_beam const& osb) {
-    os << "p(" << osb.b.p << ") eos(" << std::boolalpha << osb.b.eos() << ") tokens(";
-    for (llama_token const token_id : osb.b.tokens) {
-        os << llama_token_to_str(osb.ctx, token_id);
-    }
-    return os << ')';
-}
-
 // A struct for calculating logit-related info.
 struct logit_info {
     float const* const logits;
@@ -2977,18 +2962,16 @@ struct beam_search {
     // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
     bool common_prefix_evaluated;
 
-    // Memory used by beam_state
-    std::vector<size_t> beam_lengths;
-    std::vector<llama_token const*> beam_ptrs;
+    // Temporary memory used by beams_state to pass back via callback.
+    std::vector<beam_view> beam_views;
 
     beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
       : ctx(ctx)
      , beam_width(beam_width)
      , n_past(n_past)
      , n_predict(n_predict)
      , n_threads(n_threads)
-     , beam_lengths(beam_width)
-     , beam_ptrs(beam_width) {
+     , beam_views(beam_width) {
         beams.reserve(beam_width);
         next_beams.reserve(beam_width);
     }
@@ -3089,11 +3072,10 @@ struct beam_search {
     // Side effect: set common_prefix_length = find_common_prefix_length();
     beams_state get_beams_state(bool const last_call) {
         for (size_t i=0 ; i<beams.size() ; ++i) {
-            beam_lengths[i] = beams[i].tokens.size();
-            beam_ptrs[i] = beams[i].tokens.data();
+            beam_views[i] = beam_view{beams[i].tokens.data(), beams[i].tokens.size(), beams[i].p};
        }
         common_prefix_length = find_common_prefix_length();
-        return {beams.size(), beam_lengths.data(), beam_ptrs.data(), common_prefix_length, last_call};
+        return {beam_views.data(), beams.size(), common_prefix_length, last_call};
     }
 
     // Loop:
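
The change in get_beams_state() replaces the two parallel vectors (beam_lengths, beam_ptrs) with a single std::vector<beam_view>, filled by aggregate-initializing one beam_view per beam. A minimal sketch of that per-beam conversion, pulled out as a standalone helper for illustration (the helper name make_beam_view is hypothetical; beam and beam_view are the structs shown in this diff):

    // Sketch: the conversion get_beams_state() performs for each beam. The view
    // borrows the beam's token storage rather than copying it, so it is valid
    // only as long as the underlying std::vector<llama_token> is unmodified.
    static beam_view make_beam_view(beam const & b) {
        return beam_view{ b.tokens.data(), b.tokens.size(), b.p };
    }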

llama.h

Lines changed: 13 additions & 5 deletions
@@ -460,16 +460,24 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    // Lightweight view of a beam
+    struct beam_view {
+        llama_token const* tokens;
+        size_t n_tokens;
+        float p; // Cumulative beam probability (renormalized relative to all beams)
+        // end-of-sentence
+        bool eos() const { return n_tokens && tokens[n_tokens-1u] == llama_token_eos(); }
+    };
+
     // Passed to beam_search_callback function.
     // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
     // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
     // These pointers are valid only during the synchronous callback, so should not be saved.
     struct beams_state {
-        size_t n_beams;                  // Number of elements in beam_lengths[] and beams[].
-        size_t const* beam_lengths;      // Length of each beam.
-        llama_token const* const* beams; // Current tokens in each beam.
-        size_t common_prefix_length;     // Current max length of prefix tokens shared by all beams.
-        bool last_call;                  // True iff this is the last callback invocation.
+        beam_view* beam_views;       // View of each beam.
+        size_t n_beams;              // Number of elements in beam_views[].
+        size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
+        bool last_call;              // True iff this is the last callback invocation.
     };
     // Must be returned by beam_search_callback function.
     struct beam_search_control {
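
Because the comment block above notes that the pointers in beams_state (and therefore in each beam_view) are valid only during the synchronous callback, a consumer that wants to keep a whole beam must copy it into owned storage before returning. A minimal sketch, assuming <vector> and the beam_view struct above (the helper name copy_beam is hypothetical):

    #include <vector>

    // Sketch: copy a beam_view's tokens into owned storage while the callback
    // is still running; the resulting vector remains valid afterwards.
    static std::vector<llama_token> copy_beam(beam_view const & bv) {
        return std::vector<llama_token>(bv.tokens, bv.tokens + bv.n_tokens);
    }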
