@@ -2962,30 +2962,28 @@ struct beam_search {
2962
2962
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2963
2963
bool common_prefix_evaluated;
2964
2964
2965
+ // Memory used by beam_state
2966
+ std::vector<size_t > beam_lengths;
2967
+ std::vector<llama_token const *> beam_ptrs;
2968
+
2965
2969
beam_search (llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2966
2970
: ctx(ctx)
2967
2971
, beam_width(beam_width)
2968
2972
, n_past(n_past)
2969
2973
, n_predict(n_predict)
2970
- , n_threads(n_threads) {
2974
+ , n_threads(n_threads)
2975
+ , beam_lengths(beam_width)
2976
+ , beam_ptrs(beam_width) {
2971
2977
beams.reserve (beam_width);
2972
2978
next_beams.reserve (beam_width);
2973
2979
}
2974
2980
2975
- // Find common_prefix_length based on beams.
2976
- // Requires beams is not empty.
2977
- size_t find_common_prefix_length () {
2978
- size_t common_prefix_length = beams[0 ].tokens .size ();
2979
- for (size_t i=1 ; i<beams.size () ; ++i) {
2980
- common_prefix_length = std::min (common_prefix_length, beams[i].tokens .size ());
2981
- for (size_t j=0 ; j<common_prefix_length ; ++j) {
2982
- if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
2983
- common_prefix_length = j;
2984
- break ;
2985
- }
2986
- }
2981
+ // Collapse beams to a single beam given by index.
2982
+ void collapse_beams (size_t const beam_idx) {
2983
+ if (0u < beam_idx) {
2984
+ std::swap (beams[0 ], beams[beam_idx]);
2987
2985
}
2988
- return common_prefix_length ;
2986
+ beams. resize ( 1 ) ;
2989
2987
}
2990
2988
2991
2989
// Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
@@ -3056,49 +3054,78 @@ struct beam_search {
3056
3054
}
3057
3055
}
3058
3056
3057
+ // Find common_prefix_length based on beams.
3058
+ // Requires beams is not empty.
3059
+ size_t find_common_prefix_length () {
3060
+ size_t common_prefix_length = beams[0 ].tokens .size ();
3061
+ for (size_t i=1 ; i<beams.size () ; ++i) {
3062
+ common_prefix_length = std::min (common_prefix_length, beams[i].tokens .size ());
3063
+ for (size_t j=0 ; j<common_prefix_length ; ++j) {
3064
+ if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
3065
+ common_prefix_length = j;
3066
+ break ;
3067
+ }
3068
+ }
3069
+ }
3070
+ return common_prefix_length;
3071
+ }
3072
+
3073
+ // Construct beams_state to send back to caller via the callback function.
3074
+ // Side effect: set common_prefix_length = find_common_prefix_length();
3075
+ beams_state get_beams_state (bool const last_call) {
3076
+ for (size_t i=0 ; i<beams.size () ; ++i) {
3077
+ beam_lengths[i] = beams[i].tokens .size ();
3078
+ beam_ptrs[i] = beams[i].tokens .data ();
3079
+ }
3080
+ common_prefix_length = find_common_prefix_length ();
3081
+ return {beams.size (), beam_lengths.data (), beam_ptrs.data (), common_prefix_length, last_call};
3082
+ }
3083
+
3059
3084
// Loop:
3060
- // * while i < n_predict, OR
3061
- // * until all of the beams have nreached end-of-sentence, OR
3062
- // * until the highest probability beam is at end-of-sentence
3085
+ // * while i < n_predict, AND
3086
+ // * any of the beams have not yet reached end-of-sentence, AND
3087
+ // * the highest probability beams (plural in case of ties) are not at end-of-sentence
3063
3088
// (since all other beam probabilities can only decrease)
3064
3089
void loop (llama_beam_search_callback_fn_t const callback, void * const callback_state) {
3065
3090
beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3066
- auto const eos = [](beam const & beam) { return beam.eos (); };
3067
- for (int i=0 ; i<n_predict && !std::all_of (beams.begin (),beams.end (),eos) && !eos (top_beam ()) ; ++i) {
3068
- common_prefix_length = find_common_prefix_length ();
3069
- llama_tokens_view const common_prefix{beams[0 ].tokens .data (), common_prefix_length};
3070
- callback (callback_state, common_prefix);
3091
+ auto const not_eos = [](beam const & beam) { return !beam.eos (); };
3092
+ for (int i=0 ; i<n_predict && std::any_of (beams.begin (),beams.end (),not_eos) &&
3093
+ !beams[top_beam_index ()].eos () ; ++i) {
3094
+ beam_search_control const control = callback (callback_state, get_beams_state (false ));
3095
+ if (control.collapse_to < beams.size ()) {
3096
+ // Caller has manually selected a specific beam. Collapse beams into it.
3097
+ collapse_beams (control.collapse_to );
3098
+ }
3099
+ if (control.stop ) {
3100
+ break ;
3101
+ }
3071
3102
common_prefix_evaluated = false ;
3072
3103
for (beam& beam : beams) {
3073
3104
fill_next_beams_by_top_probabilities (beam);
3074
3105
}
3075
3106
beams.swap (next_beams);
3076
3107
renormalize_beam_probabilities (beams);
3077
3108
std::for_each (next_beams.begin (), next_beams.end (), [](beam& beam) { beam.p = 0 .0f ; });
3078
- #if 0 // DEBUG: print current beams for this iteration
3079
- std::cout << "\n\nCurrent beams:\n";
3080
- for (size_t j=0 ; j < beams.size() ; ++j) {
3081
- std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
3082
- }
3083
- #endif
3084
3109
}
3085
- beam& top_b = top_beam ();
3086
- llama_tokens_view const top_beam_tokens{top_b.tokens .data (), top_b.tokens .size ()};
3087
- callback (callback_state, top_beam_tokens);
3110
+ collapse_beams (top_beam_index ());
3111
+ callback (callback_state, get_beams_state (true ));
3088
3112
}
3089
3113
3090
3114
// As beams grow, the cumulative probabilities decrease.
3091
3115
// Renormalize them to avoid floating point underflow.
3092
3116
static void renormalize_beam_probabilities (std::vector<beam>& beams) {
3093
3117
auto const sum_p = [](float sum, beam& beam) { return sum + beam.p ; };
3094
3118
float const inv_sum = 1 .0f / std::accumulate (beams.begin (), beams.end (), 0 .0f , sum_p);
3095
- std::for_each (beams.begin (), beams.end (), [inv_sum ](beam& beam) { beam.p *= inv_sum; });
3119
+ std::for_each (beams.begin (), beams.end (), [= ](beam& beam) { beam.p *= inv_sum; });
3096
3120
}
3097
3121
3098
- // Return beam with highest probability.
3099
- beam& top_beam () {
3100
- auto const by_p = [](beam const & a, beam const & b) { return a.p < b.p ; };
3101
- return *std::max_element (beams.begin (), beams.end (), by_p);
3122
+ // Return index of highest ranking beam by (probability,eos()).
3123
+ // In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
3124
+ // Assumes beams is non-empty.
3125
+ size_t top_beam_index () {
3126
+ auto const by_p_and_eos = [](beam const & a, beam const & b) {
3127
+ return a.p < b.p || (a.p == b.p && a.eos () < b.eos ()); };
3128
+ return std::max_element (beams.begin (), beams.end (), by_p_and_eos) - beams.begin ();
3102
3129
}
3103
3130
};
3104
3131
@@ -3110,8 +3137,6 @@ void llama_beam_search(llama_context * ctx,
3110
3137
3111
3138
beam_search beam_search (ctx, beam_width, n_past, n_predict, n_threads);
3112
3139
3113
- // callback(callback_state, common_beam_prefix) is called on each iteration, and when
3114
- // stop condition is met with remaining tokens from beam with the highest probability.
3115
3140
beam_search.loop (callback, callback_state);
3116
3141
3117
3142
ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
0 commit comments