Skip to content

Commit a5a220b

Browse files
committed
Use llama_ prefix for struct names.
1 parent 97cce16 commit a5a220b

File tree

3 files changed

+41
-43
lines changed

3 files changed

+41
-43
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@
3030
// Used for debugging to print out beam tokens.
3131
struct ostream_beam_view {
3232
llama_context* ctx;
33-
beam_view bv;
33+
llama_beam_view beam_view;
3434
};
3535
std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
36-
os << "p(" << obv.bv.p << ") eos(" << std::boolalpha << obv.bv.eos() << ") tokens(";
37-
for (size_t i=0 ; i<obv.bv.n_tokens ; ++i) {
38-
os << llama_token_to_str(obv.ctx, obv.bv.tokens[i]);
36+
os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos() << ") tokens(";
37+
for (size_t i=0 ; i<obv.beam_view.n_tokens ; ++i) {
38+
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
3939
}
4040
return os << ')';
4141
}
@@ -52,7 +52,7 @@ struct beam_search_callback_state {
5252
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
5353
// This is also called when the stop condition is met.
5454
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
55-
beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
55+
llama_beam_search_control beam_search_callback(void* callback_state, llama_beams_state const beams_state) {
5656
auto const state = *static_cast<beam_search_callback_state*>(callback_state);
5757
printf(","); // Show progress
5858
if (size_t const n = beams_state.common_prefix_length) {
@@ -69,11 +69,10 @@ beam_search_control beam_search_callback(void* callback_state, beams_state const
6969
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
7070
}
7171
#endif
72-
beam_search_control control {
72+
return llama_beam_search_control{
7373
beams_state.n_beams, // = collapse_to. Any index out of range means do not collapse beams.
7474
false // = stop. Don't stop beam search.
7575
};
76-
return control;
7776
}
7877

7978
int main(int argc, char ** argv)

llama.cpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
#include <ctime>
3737
#include <cinttypes>
3838
#include <fstream>
39-
#include <functional>
4039
#include <random>
4140
#include <map>
4241
#include <unordered_map>
@@ -2891,7 +2890,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
28912890
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
28922891
}
28932892

2894-
struct beam {
2893+
struct llama_beam {
28952894
std::vector<llama_token> tokens;
28962895
float p; // Cumulative beam probability (renormalized relative to all beams)
28972896
// end-of-sentence
@@ -2954,16 +2953,16 @@ struct beam_search {
29542953
int n_past;
29552954
int n_predict;
29562955
int n_threads;
2957-
std::vector<beam> beams;
2958-
std::vector<beam> next_beams;
2956+
std::vector<llama_beam> beams;
2957+
std::vector<llama_beam> next_beams;
29592958

29602959
// Re-calculated on each loop iteration
29612960
size_t common_prefix_length;
29622961
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29632962
bool common_prefix_evaluated;
29642963

2965-
// Temporary memory used by beams_state to pass back via callback.
2966-
std::vector<beam_view> beam_views;
2964+
// Temporary memory used by llama_beams_state to pass back via callback.
2965+
std::vector<llama_beam_view> beam_views;
29672966

29682967
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
29692968
: ctx(ctx)
@@ -2989,32 +2988,32 @@ struct beam_search {
29892988
// * Gather elements until the vector is full, then call std::make_heap() on it.
29902989
// * If the heap is full and a new element is found that should be included, pop the
29912990
// least element to the back(), replace it with the new, then push it into the heap.
2992-
void fill_next_beams_by_top_probabilities(beam& b) {
2991+
void fill_next_beams_by_top_probabilities(llama_beam& beam) {
29932992
// Min-heaps use a greater-than comparator.
2994-
auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
2993+
auto const comp = [](llama_beam const& a, llama_beam const& b) { return a.p > b.p; };
29952994
if (common_prefix_evaluated) {
29962995
// llama_eval was already called during this iteration
29972996
// with the common token prefix, so shift it off this beam.
2998-
b.shift_tokens(common_prefix_length);
2997+
beam.shift_tokens(common_prefix_length);
29992998
}
3000-
if (b.eos()) {
2999+
if (beam.eos()) {
30013000
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
30023001
if (next_beams.size() < beam_width) {
3003-
next_beams.push_back(std::move(b));
3002+
next_beams.push_back(std::move(beam));
30043003
if (next_beams.size() == beam_width) {
30053004
std::make_heap(next_beams.begin(), next_beams.end(), comp);
30063005
}
3007-
} else if (next_beams.front().p < b.p) {
3006+
} else if (next_beams.front().p < beam.p) {
30083007
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3009-
next_beams.back() = std::move(b);
3008+
next_beams.back() = std::move(beam);
30103009
std::push_heap(next_beams.begin(), next_beams.end(), comp);
30113010
}
30123011
} else {
30133012
// beam is not at end-of-sentence, so branch with next top_k tokens.
3014-
if (!b.tokens.empty()) {
3015-
llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
3013+
if (!beam.tokens.empty()) {
3014+
llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
30163015
if (!common_prefix_evaluated && common_prefix_length) {
3017-
b.shift_tokens(common_prefix_length);
3016+
beam.shift_tokens(common_prefix_length);
30183017
n_past += common_prefix_length;
30193018
common_prefix_evaluated = true;
30203019
}
@@ -3024,7 +3023,7 @@ struct beam_search {
30243023
size_t i=0;
30253024
if (next_beams.size() < beam_width) {
30263025
for (; next_beams.size() < beam_width ; ++i) {
3027-
beam next_beam = b;
3026+
llama_beam next_beam = beam;
30283027
next_beam.tokens.push_back(next_tokens[i].id);
30293028
next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
30303029
next_beams.push_back(std::move(next_beam));
@@ -3033,17 +3032,17 @@ struct beam_search {
30333032
} else {
30343033
for (; next_beams.front().p == 0.0f ; ++i) {
30353034
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3036-
next_beams.back() = b;
3035+
next_beams.back() = beam;
30373036
next_beams.back().tokens.push_back(next_tokens[i].id);
30383037
next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
30393038
std::push_heap(next_beams.begin(), next_beams.end(), comp);
30403039
}
30413040
}
30423041
for (; i < beam_width ; ++i) {
3043-
float const next_p = b.p * logit_info.probability_from_logit(next_tokens[i].logit);
3042+
float const next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
30443043
if (next_beams.front().p < next_p) {
30453044
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3046-
next_beams.back() = b;
3045+
next_beams.back() = beam;
30473046
next_beams.back().tokens.push_back(next_tokens[i].id);
30483047
next_beams.back().p = next_p;
30493048
std::push_heap(next_beams.begin(), next_beams.end(), comp);
@@ -3070,9 +3069,9 @@ struct beam_search {
30703069

30713070
// Construct beams_state to send back to caller via the callback function.
30723071
// Side effect: set common_prefix_length = find_common_prefix_length();
3073-
beams_state get_beams_state(bool const last_call) {
3072+
llama_beams_state get_beams_state(bool const last_call) {
30743073
for (size_t i=0 ; i<beams.size() ; ++i) {
3075-
beam_views[i] = beam_view{beams[i].tokens.data(), beams[i].tokens.size(), beams[i].p};
3074+
beam_views[i] = llama_beam_view{beams[i].tokens.data(), beams[i].tokens.size(), beams[i].p};
30763075
}
30773076
common_prefix_length = find_common_prefix_length();
30783077
return {beam_views.data(), beams.size(), common_prefix_length, last_call};
@@ -3085,10 +3084,10 @@ struct beam_search {
30853084
// (since all other beam probabilities can only decrease)
30863085
void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
30873086
beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
3088-
auto const not_eos = [](beam const& beam) { return !beam.eos(); };
3087+
auto const not_eos = [](llama_beam const& beam) { return !beam.eos(); };
30893088
for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
30903089
!beams[top_beam_index()].eos() ; ++i) {
3091-
beam_search_control const control = callback(callback_state, get_beams_state(false));
3090+
llama_beam_search_control const control = callback(callback_state, get_beams_state(false));
30923091
if (control.collapse_to < beams.size()) {
30933092
// Caller has manually selected a specific beam. Collapse beams into it.
30943093
collapse_beams(control.collapse_to);
@@ -3097,30 +3096,30 @@ struct beam_search {
30973096
break;
30983097
}
30993098
common_prefix_evaluated = false;
3100-
for (beam& beam : beams) {
3099+
for (llama_beam& beam : beams) {
31013100
fill_next_beams_by_top_probabilities(beam);
31023101
}
31033102
beams.swap(next_beams);
31043103
renormalize_beam_probabilities(beams);
3105-
std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
3104+
std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
31063105
}
31073106
collapse_beams(top_beam_index());
31083107
callback(callback_state, get_beams_state(true));
31093108
}
31103109

31113110
// As beams grow, the cumulative probabilities decrease.
31123111
// Renormalize them to avoid floating point underflow.
3113-
static void renormalize_beam_probabilities(std::vector<beam>& beams) {
3114-
auto const sum_p = [](float sum, beam& beam) { return sum + beam.p; };
3112+
static void renormalize_beam_probabilities(std::vector<llama_beam>& beams) {
3113+
auto const sum_p = [](float sum, llama_beam& beam) { return sum + beam.p; };
31153114
float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
3116-
std::for_each(beams.begin(), beams.end(), [=](beam& beam) { beam.p *= inv_sum; });
3115+
std::for_each(beams.begin(), beams.end(), [=](llama_beam& beam) { beam.p *= inv_sum; });
31173116
}
31183117

31193118
// Return index of highest ranking beam by (probability,eos()).
31203119
// In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
31213120
// Assumes beams is non-empty.
31223121
size_t top_beam_index() {
3123-
auto const by_p_and_eos = [](beam const& a, beam const& b) {
3122+
auto const by_p_and_eos = [](llama_beam const& a, llama_beam const& b) {
31243123
return a.p < b.p || (a.p == b.p && a.eos() < b.eos()); };
31253124
return std::max_element(beams.begin(), beams.end(), by_p_and_eos) - beams.begin();
31263125
}

llama.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ extern "C" {
461461
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
462462

463463
// Lightweight view of a beam
464-
struct beam_view {
464+
struct llama_beam_view {
465465
llama_token const* tokens;
466466
size_t n_tokens;
467467
float p; // Cumulative beam probability (renormalized relative to all beams)
@@ -473,27 +473,27 @@ extern "C" {
473473
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
474474
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
475475
// These pointers are valid only during the synchronous callback, so should not be saved.
476-
struct beams_state {
477-
beam_view* beam_views; // View of each beam.
476+
struct llama_beams_state {
477+
llama_beam_view* beam_views; // View of each beam.
478478
size_t n_beams; // Number of elements in beam_views[].
479479
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
480480
bool last_call; // True iff this is the last callback invocation.
481481
};
482482
// Must be returned by beam_search_callback function.
483-
struct beam_search_control {
483+
struct llama_beam_search_control {
484484
size_t collapse_to; // Collapse to a beam index. Ignored if n_beams <= collapse_to.
485485
bool stop; // Stop beam search. Set to false to continue.
486486
};
487487
// Type of pointer to the beam_search_callback function.
488488
// void* callback_state is any custom data passed to llama_beam_search, that is subsequently
489489
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
490-
typedef beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, beams_state);
490+
typedef llama_beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, llama_beams_state);
491491

492492
/// @details Deterministically returns entire sentence constructed by a beam search.
493493
/// @param ctx Pointer to the llama_context.
494494
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
495495
/// The return beam_search_control can be used to control the beam_search execution.
496-
/// @param callback_state A pointer that is passed back to callback and nothing more.
496+
/// @param callback_state A pointer that is simply passed back to callback.
497497
/// @param beam_width The number of parallel beams to use.
498498
/// @param n_past The number of tokens already evaluated.
499499
/// @param n_predict The maximum number of tokens to predict.

0 commit comments

Comments
 (0)