@@ -2977,30 +2977,28 @@ struct beam_search {
2977
2977
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2978
2978
bool common_prefix_evaluated;
2979
2979
2980
+ // Memory used by beam_state
2981
+ std::vector<size_t > beam_lengths;
2982
+ std::vector<llama_token const *> beam_ptrs;
2983
+
2980
2984
beam_search (llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2981
2985
: ctx(ctx)
2982
2986
, beam_width(beam_width)
2983
2987
, n_past(n_past)
2984
2988
, n_predict(n_predict)
2985
- , n_threads(n_threads) {
2989
+ , n_threads(n_threads)
2990
+ , beam_lengths(beam_width)
2991
+ , beam_ptrs(beam_width) {
2986
2992
beams.reserve (beam_width);
2987
2993
next_beams.reserve (beam_width);
2988
2994
}
2989
2995
2990
- // Find common_prefix_length based on beams.
2991
- // Requires beams is not empty.
2992
- size_t find_common_prefix_length () {
2993
- size_t common_prefix_length = beams[0 ].tokens .size ();
2994
- for (size_t i=1 ; i<beams.size () ; ++i) {
2995
- common_prefix_length = std::min (common_prefix_length, beams[i].tokens .size ());
2996
- for (size_t j=0 ; j<common_prefix_length ; ++j) {
2997
- if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
2998
- common_prefix_length = j;
2999
- break ;
3000
- }
3001
- }
2996
+ // Collapse beams to a single beam given by index.
2997
+ void collapse_beams (size_t const beam_idx) {
2998
+ if (0u < beam_idx) {
2999
+ std::swap (beams[0 ], beams[beam_idx]);
3002
3000
}
3003
- return common_prefix_length ;
3001
+ beams. resize ( 1 ) ;
3004
3002
}
3005
3003
3006
3004
// Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
@@ -3071,49 +3069,78 @@ struct beam_search {
3071
3069
}
3072
3070
}
3073
3071
3072
+ // Find common_prefix_length based on beams.
3073
+ // Requires beams is not empty.
3074
+ size_t find_common_prefix_length () {
3075
+ size_t common_prefix_length = beams[0 ].tokens .size ();
3076
+ for (size_t i=1 ; i<beams.size () ; ++i) {
3077
+ common_prefix_length = std::min (common_prefix_length, beams[i].tokens .size ());
3078
+ for (size_t j=0 ; j<common_prefix_length ; ++j) {
3079
+ if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
3080
+ common_prefix_length = j;
3081
+ break ;
3082
+ }
3083
+ }
3084
+ }
3085
+ return common_prefix_length;
3086
+ }
3087
+
3088
+ // Construct beams_state to send back to caller via the callback function.
3089
+ // Side effect: set common_prefix_length = find_common_prefix_length();
3090
+ beams_state get_beams_state (bool const last_call) {
3091
+ for (size_t i=0 ; i<beams.size () ; ++i) {
3092
+ beam_lengths[i] = beams[i].tokens .size ();
3093
+ beam_ptrs[i] = beams[i].tokens .data ();
3094
+ }
3095
+ common_prefix_length = find_common_prefix_length ();
3096
+ return {beams.size (), beam_lengths.data (), beam_ptrs.data (), common_prefix_length, last_call};
3097
+ }
3098
+
3074
3099
// Loop:
3075
- // * while i < n_predict, OR
3076
- // * until all of the beams have nreached end-of-sentence, OR
3077
- // * until the highest probability beam is at end-of-sentence
3100
+ // * while i < n_predict, AND
3101
+ // * any of the beams have not yet reached end-of-sentence, AND
3102
+ // * the highest probability beams (plural in case of ties) are not at end-of-sentence
3078
3103
// (since all other beam probabilities can only decrease)
3079
3104
void loop (llama_beam_search_callback_fn_t const callback, void * const callback_state) {
3080
3105
beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3081
- auto const eos = [](beam const & beam) { return beam.eos (); };
3082
- for (int i=0 ; i<n_predict && !std::all_of (beams.begin (),beams.end (),eos) && !eos (top_beam ()) ; ++i) {
3083
- common_prefix_length = find_common_prefix_length ();
3084
- llama_tokens_view const common_prefix{beams[0 ].tokens .data (), common_prefix_length};
3085
- callback (callback_state, common_prefix);
3106
+ auto const not_eos = [](beam const & beam) { return !beam.eos (); };
3107
+ for (int i=0 ; i<n_predict && std::any_of (beams.begin (),beams.end (),not_eos) &&
3108
+ !beams[top_beam_index ()].eos () ; ++i) {
3109
+ beam_search_control const control = callback (callback_state, get_beams_state (false ));
3110
+ if (control.collapse_to < beams.size ()) {
3111
+ // Caller has manually selected a specific beam. Collapse beams into it.
3112
+ collapse_beams (control.collapse_to );
3113
+ }
3114
+ if (control.stop ) {
3115
+ break ;
3116
+ }
3086
3117
common_prefix_evaluated = false ;
3087
3118
for (beam& beam : beams) {
3088
3119
fill_next_beams_by_top_probabilities (beam);
3089
3120
}
3090
3121
beams.swap (next_beams);
3091
3122
renormalize_beam_probabilities (beams);
3092
3123
std::for_each (next_beams.begin (), next_beams.end (), [](beam& beam) { beam.p = 0 .0f ; });
3093
- #if 0 // DEBUG: print current beams for this iteration
3094
- std::cout << "\n\nCurrent beams:\n";
3095
- for (size_t j=0 ; j < beams.size() ; ++j) {
3096
- std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
3097
- }
3098
- #endif
3099
3124
}
3100
- beam& top_b = top_beam ();
3101
- llama_tokens_view const top_beam_tokens{top_b.tokens .data (), top_b.tokens .size ()};
3102
- callback (callback_state, top_beam_tokens);
3125
+ collapse_beams (top_beam_index ());
3126
+ callback (callback_state, get_beams_state (true ));
3103
3127
}
3104
3128
3105
3129
// As beams grow, the cumulative probabilities decrease.
3106
3130
// Renormalize them to avoid floating point underflow.
3107
3131
static void renormalize_beam_probabilities (std::vector<beam>& beams) {
3108
3132
auto const sum_p = [](float sum, beam& beam) { return sum + beam.p ; };
3109
3133
float const inv_sum = 1 .0f / std::accumulate (beams.begin (), beams.end (), 0 .0f , sum_p);
3110
- std::for_each (beams.begin (), beams.end (), [inv_sum ](beam& beam) { beam.p *= inv_sum; });
3134
+ std::for_each (beams.begin (), beams.end (), [= ](beam& beam) { beam.p *= inv_sum; });
3111
3135
}
3112
3136
3113
- // Return beam with highest probability.
3114
- beam& top_beam () {
3115
- auto const by_p = [](beam const & a, beam const & b) { return a.p < b.p ; };
3116
- return *std::max_element (beams.begin (), beams.end (), by_p);
3137
+ // Return index of highest ranking beam by (probability,eos()).
3138
+ // In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
3139
+ // Assumes beams is non-empty.
3140
+ size_t top_beam_index () {
3141
+ auto const by_p_and_eos = [](beam const & a, beam const & b) {
3142
+ return a.p < b.p || (a.p == b.p && a.eos () < b.eos ()); };
3143
+ return std::max_element (beams.begin (), beams.end (), by_p_and_eos) - beams.begin ();
3117
3144
}
3118
3145
};
3119
3146
@@ -3125,8 +3152,6 @@ void llama_beam_search(llama_context * ctx,
3125
3152
3126
3153
beam_search beam_search (ctx, beam_width, n_past, n_predict, n_threads);
3127
3154
3128
- // callback(callback_state, common_beam_prefix) is called on each iteration, and when
3129
- // stop condition is met with remaining tokens from beam with the highest probability.
3130
3155
beam_search.loop (callback, callback_state);
3131
3156
3132
3157
ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
0 commit comments