36
36
#include < ctime>
37
37
#include < cinttypes>
38
38
#include < fstream>
39
- #include < functional>
40
39
#include < random>
41
40
#include < map>
42
41
#include < unordered_map>
@@ -2891,7 +2890,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
2891
2890
ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
2892
2891
}
2893
2892
2894
- struct beam {
2893
+ struct llama_beam {
2895
2894
std::vector<llama_token> tokens;
2896
2895
float p; // Cumulative beam probability (renormalized relative to all beams)
2897
2896
// end-of-sentence
@@ -2954,16 +2953,16 @@ struct beam_search {
2954
2953
int n_past;
2955
2954
int n_predict;
2956
2955
int n_threads;
2957
- std::vector<beam > beams;
2958
- std::vector<beam > next_beams;
2956
+ std::vector<llama_beam > beams;
2957
+ std::vector<llama_beam > next_beams;
2959
2958
2960
2959
// Re-calculated on each loop iteration
2961
2960
size_t common_prefix_length;
2962
2961
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2963
2962
bool common_prefix_evaluated;
2964
2963
2965
- // Temporary memory used by beams_state to pass back via callback.
2966
- std::vector<beam_view > beam_views;
2964
+ // Temporary memory used by llama_beams_state to pass back via callback.
2965
+ std::vector<llama_beam_view > beam_views;
2967
2966
2968
2967
beam_search (llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
2969
2968
: ctx(ctx)
@@ -2989,32 +2988,32 @@ struct beam_search {
2989
2988
// * Gather elements until the vector is full, then call std::make_heap() on it.
2990
2989
// * If the heap is full and a new element is found that should be included, pop the
2991
2990
// least element to the back(), replace it with the new, then push it into the heap.
2992
- void fill_next_beams_by_top_probabilities (beam& b ) {
2991
+ void fill_next_beams_by_top_probabilities (llama_beam& beam ) {
2993
2992
// Min-heaps use a greater-than comparator.
2994
- auto const comp = [](beam const & a, beam const & b) { return a.p > b.p ; };
2993
+ auto const comp = [](llama_beam const & a, llama_beam const & b) { return a.p > b.p ; };
2995
2994
if (common_prefix_evaluated) {
2996
2995
// llama_eval was already called during this iteration
2997
2996
// with the common token prefix, so shift it off this beam.
2998
- b .shift_tokens (common_prefix_length);
2997
+ beam .shift_tokens (common_prefix_length);
2999
2998
}
3000
- if (b .eos ()) {
2999
+ if (beam .eos ()) {
3001
3000
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3002
3001
if (next_beams.size () < beam_width) {
3003
- next_beams.push_back (std::move (b ));
3002
+ next_beams.push_back (std::move (beam ));
3004
3003
if (next_beams.size () == beam_width) {
3005
3004
std::make_heap (next_beams.begin (), next_beams.end (), comp);
3006
3005
}
3007
- } else if (next_beams.front ().p < b .p ) {
3006
+ } else if (next_beams.front ().p < beam .p ) {
3008
3007
std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3009
- next_beams.back () = std::move (b );
3008
+ next_beams.back () = std::move (beam );
3010
3009
std::push_heap (next_beams.begin (), next_beams.end (), comp);
3011
3010
}
3012
3011
} else {
3013
3012
// beam is not at end-of-sentence, so branch with next top_k tokens.
3014
- if (!b .tokens .empty ()) {
3015
- llama_eval (ctx, b .tokens .data (), b .tokens .size (), n_past, n_threads);
3013
+ if (!beam .tokens .empty ()) {
3014
+ llama_eval (ctx, beam .tokens .data (), beam .tokens .size (), n_past, n_threads);
3016
3015
if (!common_prefix_evaluated && common_prefix_length) {
3017
- b .shift_tokens (common_prefix_length);
3016
+ beam .shift_tokens (common_prefix_length);
3018
3017
n_past += common_prefix_length;
3019
3018
common_prefix_evaluated = true ;
3020
3019
}
@@ -3024,7 +3023,7 @@ struct beam_search {
3024
3023
size_t i=0 ;
3025
3024
if (next_beams.size () < beam_width) {
3026
3025
for (; next_beams.size () < beam_width ; ++i) {
3027
- beam next_beam = b ;
3026
+ llama_beam next_beam = beam ;
3028
3027
next_beam.tokens .push_back (next_tokens[i].id );
3029
3028
next_beam.p *= logit_info.probability_from_logit (next_tokens[i].logit );
3030
3029
next_beams.push_back (std::move (next_beam));
@@ -3033,17 +3032,17 @@ struct beam_search {
3033
3032
} else {
3034
3033
for (; next_beams.front ().p == 0 .0f ; ++i) {
3035
3034
std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3036
- next_beams.back () = b ;
3035
+ next_beams.back () = beam ;
3037
3036
next_beams.back ().tokens .push_back (next_tokens[i].id );
3038
3037
next_beams.back ().p *= logit_info.probability_from_logit (next_tokens[i].logit );
3039
3038
std::push_heap (next_beams.begin (), next_beams.end (), comp);
3040
3039
}
3041
3040
}
3042
3041
for (; i < beam_width ; ++i) {
3043
- float const next_p = b .p * logit_info.probability_from_logit (next_tokens[i].logit );
3042
+ float const next_p = beam .p * logit_info.probability_from_logit (next_tokens[i].logit );
3044
3043
if (next_beams.front ().p < next_p) {
3045
3044
std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3046
- next_beams.back () = b ;
3045
+ next_beams.back () = beam ;
3047
3046
next_beams.back ().tokens .push_back (next_tokens[i].id );
3048
3047
next_beams.back ().p = next_p;
3049
3048
std::push_heap (next_beams.begin (), next_beams.end (), comp);
@@ -3070,9 +3069,9 @@ struct beam_search {
3070
3069
3071
3070
// Construct beams_state to send back to caller via the callback function.
3072
3071
// Side effect: set common_prefix_length = find_common_prefix_length();
3073
- beams_state get_beams_state (bool const last_call) {
3072
+ llama_beams_state get_beams_state (bool const last_call) {
3074
3073
for (size_t i=0 ; i<beams.size () ; ++i) {
3075
- beam_views[i] = beam_view {beams[i].tokens .data (), beams[i].tokens .size (), beams[i].p };
3074
+ beam_views[i] = llama_beam_view {beams[i].tokens .data (), beams[i].tokens .size (), beams[i].p };
3076
3075
}
3077
3076
common_prefix_length = find_common_prefix_length ();
3078
3077
return {beam_views.data (), beams.size (), common_prefix_length, last_call};
@@ -3085,10 +3084,10 @@ struct beam_search {
3085
3084
// (since all other beam probabilities can only decrease)
3086
3085
void loop (llama_beam_search_callback_fn_t const callback, void * const callback_state) {
3087
3086
beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3088
- auto const not_eos = [](beam const & beam) { return !beam.eos (); };
3087
+ auto const not_eos = [](llama_beam const & beam) { return !beam.eos (); };
3089
3088
for (int i=0 ; i<n_predict && std::any_of (beams.begin (),beams.end (),not_eos) &&
3090
3089
!beams[top_beam_index ()].eos () ; ++i) {
3091
- beam_search_control const control = callback (callback_state, get_beams_state (false ));
3090
+ llama_beam_search_control const control = callback (callback_state, get_beams_state (false ));
3092
3091
if (control.collapse_to < beams.size ()) {
3093
3092
// Caller has manually selected a specific beam. Collapse beams into it.
3094
3093
collapse_beams (control.collapse_to );
@@ -3097,30 +3096,30 @@ struct beam_search {
3097
3096
break ;
3098
3097
}
3099
3098
common_prefix_evaluated = false ;
3100
- for (beam & beam : beams) {
3099
+ for (llama_beam & beam : beams) {
3101
3100
fill_next_beams_by_top_probabilities (beam);
3102
3101
}
3103
3102
beams.swap (next_beams);
3104
3103
renormalize_beam_probabilities (beams);
3105
- std::for_each (next_beams.begin (), next_beams.end (), [](beam & beam) { beam.p = 0 .0f ; });
3104
+ std::for_each (next_beams.begin (), next_beams.end (), [](llama_beam & beam) { beam.p = 0 .0f ; });
3106
3105
}
3107
3106
collapse_beams (top_beam_index ());
3108
3107
callback (callback_state, get_beams_state (true ));
3109
3108
}
3110
3109
3111
3110
// As beams grow, the cumulative probabilities decrease.
3112
3111
// Renormalize them to avoid floating point underflow.
3113
- static void renormalize_beam_probabilities (std::vector<beam >& beams) {
3114
- auto const sum_p = [](float sum, beam & beam) { return sum + beam.p ; };
3112
+ static void renormalize_beam_probabilities (std::vector<llama_beam >& beams) {
3113
+ auto const sum_p = [](float sum, llama_beam & beam) { return sum + beam.p ; };
3115
3114
float const inv_sum = 1 .0f / std::accumulate (beams.begin (), beams.end (), 0 .0f , sum_p);
3116
- std::for_each (beams.begin (), beams.end (), [=](beam & beam) { beam.p *= inv_sum; });
3115
+ std::for_each (beams.begin (), beams.end (), [=](llama_beam & beam) { beam.p *= inv_sum; });
3117
3116
}
3118
3117
3119
3118
// Return index of highest ranking beam by (probability,eos()).
3120
3119
// In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
3121
3120
// Assumes beams is non-empty.
3122
3121
size_t top_beam_index () {
3123
- auto const by_p_and_eos = [](beam const & a, beam const & b) {
3122
+ auto const by_p_and_eos = [](llama_beam const & a, llama_beam const & b) {
3124
3123
return a.p < b.p || (a.p == b.p && a.eos () < b.eos ()); };
3125
3124
return std::max_element (beams.begin (), beams.end (), by_p_and_eos) - beams.begin ();
3126
3125
}
0 commit comments