@@ -2977,8 +2977,8 @@ struct beam_search {
2977
2977
int common_prefix_length;
2978
2978
// true iff llama_eval() has been called with common prefix in current loop iteration.
2979
2979
bool common_prefix_evaluated;
2980
- // Save token prefix common to all beams here
2981
- std::vector<llama_token> response ;
2980
+ // Save token prefix common to all beams. Cleared after each loop iteration.
2981
+ std::vector<llama_token> common_prefix ;
2982
2982
2983
2983
beam_search (llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2984
2984
: ctx(ctx)
@@ -3005,11 +3005,11 @@ struct beam_search {
3005
3005
return common_prefix_length;
3006
3006
}
3007
3007
3008
- // Min-heaps are used to efficiently gather the top-k elements (k=beam_width).
3008
+ // Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
3009
3009
// The repetitive patterns below reflect the 2 stages of heaps:
3010
3010
// * Gather elements until the vector is full, then call std::make_heap() on it.
3011
- // * If the heap is full and a new element is found that should be included,
3012
- // pop off the least element , replace it with the new, then push it into the heap.
3011
+ // * If the heap is full and a new element is found that should be included, pop the
3012
+ // least element to the back() , replace it with the new, then push it into the heap.
3013
3013
void fill_next_beams_by_top_probabilities (beam& b) {
3014
3014
// Min-heaps use a greater-than comparator.
3015
3015
auto const comp = [](beam const & a, beam const & b) { return a.p > b.p ; };
@@ -3021,7 +3021,7 @@ struct beam_search {
3021
3021
if (b.eos ()) {
3022
3022
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3023
3023
if (next_beams.size () < static_cast <size_t >(beam_width)) {
3024
- next_beams.push_back (b );
3024
+ next_beams.push_back (std::move (b) );
3025
3025
if (next_beams.size () == static_cast <size_t >(beam_width)) {
3026
3026
std::make_heap (next_beams.begin (), next_beams.end (), comp);
3027
3027
}
@@ -3035,7 +3035,7 @@ struct beam_search {
3035
3035
if (!b.tokens .empty ()) {
3036
3036
llama_eval (ctx, b.tokens .data (), b.tokens .size (), n_past, n_threads);
3037
3037
if (!common_prefix_evaluated && common_prefix_length) {
3038
- b.shift_tokens (response , common_prefix_length);
3038
+ b.shift_tokens (common_prefix , common_prefix_length);
3039
3039
n_past += common_prefix_length;
3040
3040
common_prefix_evaluated = true ;
3041
3041
}
@@ -3074,12 +3074,12 @@ struct beam_search {
3074
3074
}
3075
3075
3076
3076
// Loop:
3077
- // * while i < n_predict
3078
- // * until all of the beams have nreached end-of-sentence
3077
+ // * while i < n_predict, OR
3078
+ // * until all of the beams have reached end-of-sentence, OR
3079
3079
// * until the highest probability beam is at end-of-sentence
3080
3080
// (since all other beam probabilities can only decrease)
3081
- void loop (std::function< void (std::vector<beam>&)> const callback) {
3082
- beams.push_back ({{}, 1 .0f });
3081
+ void loop (llama_beam_search_callback_fn_t const callback, void * const callback_state ) {
3082
+ beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3083
3083
auto const eos = [](beam const & beam) { return beam.eos (); };
3084
3084
for (int i=0 ; i<n_predict && !std::all_of (beams.begin (),beams.end (),eos) && !eos (top_beam ()) ; ++i) {
3085
3085
common_prefix_evaluated = false ;
@@ -3090,10 +3090,14 @@ struct beam_search {
3090
3090
beams.swap (next_beams);
3091
3091
renormalize_beam_probabilities (beams);
3092
3092
std::for_each (next_beams.begin (), next_beams.end (), [](beam& beam) { beam.p = 0 .0f ; });
3093
- callback (beams);
3093
+ llama_tokens_view const common_beam_prefix{common_prefix.data (), common_prefix.size ()};
3094
+ callback (callback_state, common_beam_prefix);
3095
+ common_prefix.clear ();
3094
3096
}
3095
3097
beam& top_b = top_beam ();
3096
- top_b.shift_tokens (response, top_b.tokens .size ());
3098
+ top_b.shift_tokens (common_prefix, top_b.tokens .size ());
3099
+ llama_tokens_view const common_beam_prefix{common_prefix.data (), common_prefix.size ()};
3100
+ callback (callback_state, common_beam_prefix);
3097
3101
}
3098
3102
3099
3103
// As beams grow, the cumulative probabilities decrease.
@@ -3111,38 +3115,20 @@ struct beam_search {
3111
3115
}
3112
3116
};
3113
3117
3114
- // Not thread-safe.
3115
- const char * llama_beam_search (llama_context * ctx, int beam_width,
3116
- int n_past, int const n_predict, int const n_threads) {
3117
- static std::string beam_search_response;
3118
+ void llama_beam_search (llama_context * ctx,
3119
+ llama_beam_search_callback_fn_t callback, void * callback_state,
3120
+ int beam_width, int n_past, int const n_predict, int const n_threads) {
3118
3121
assert (ctx);
3119
3122
const int64_t t_start_sample_us = ggml_time_us ();
3120
3123
3121
3124
beam_search beam_search (ctx, beam_width, n_past, n_predict, n_threads);
3122
3125
3123
- beam_search.loop ([&](std::vector<beam>& beams) {
3124
- #if 1 // DEBUG: print current beams for this iteration
3125
- std::cout << " \n\n Current beams:\n " ;
3126
- for (size_t j=0 ; j < beams.size () ; ++j) {
3127
- std::cout << " beams[" <<j<<" ]: " ;
3128
- out_beam (std::cout, ctx, beams[j]);
3129
- std::cout << std::endl;
3130
- }
3131
- #else
3132
- std::cout << '.' << std::flush; // Show progress
3133
- #endif
3134
- });
3135
-
3136
- // Save beam sentence to beam_search_response. Is there a better way?
3137
- std::ostringstream oss;
3138
- for (llama_token const token : beam_search.response ) {
3139
- oss << llama_token_to_str (ctx, token);
3140
- }
3141
- beam_search_response = oss.str ();
3126
+ // callback(callback_state, common_beam_prefix) is called on each iteration, and when
3127
+ // the stop condition is met, with the remaining tokens from the highest-probability beam.
3128
+ beam_search.loop (callback, callback_state);
3142
3129
3143
3130
ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
3144
3131
ctx->n_sample ++;
3145
- return beam_search_response.c_str ();
3146
3132
}
3147
3133
3148
3134
//
0 commit comments