@@ -2962,8 +2962,8 @@ struct beam_search {
     int common_prefix_length;
     // true iff llama_eval() has been called with common prefix in current loop iteration.
     bool common_prefix_evaluated;
-    // Save token prefix common to all beams here
-    std::vector<llama_token> response;
+    // Save token prefix common to all beams. Cleared after each loop iteration.
+    std::vector<llama_token> common_prefix;

     beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
       : ctx(ctx)
@@ -2990,11 +2990,11 @@ struct beam_search {
         return common_prefix_length;
     }

-    // Min-heaps are used to efficiently gather the top-k elements (k=beam_width).
+    // Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
     // The repetitive patterns below reflect the 2 stages of heaps:
     // * Gather elements until the vector is full, then call std::make_heap() on it.
-    // * If the heap is full and a new element is found that should be included,
-    //   pop off the least element, replace it with the new, then push it into the heap.
+    // * If the heap is full and a new element is found that should be included, pop the
+    //   least element to the back(), replace it with the new, then push it into the heap.
     void fill_next_beams_by_top_probabilities(beam& b) {
         // Min-heaps use a greater-than comparator.
         auto const comp = [](beam const & a, beam const & b) { return a.p > b.p; };
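(Annotation: the two-stage heap pattern described above can be sketched in isolation. This is a minimal, self-contained illustration of the same `std::make_heap`/`std::pop_heap`/`std::push_heap` idiom; the `item` struct and the probabilities are made up for the example.)

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

struct item { float p; };

int main() {
    // Greater-than comparator => std::make_heap builds a MIN-heap,
    // so front() is always the smallest element, i.e. the eviction candidate.
    auto const comp = [](item const & a, item const & b) { return a.p > b.p; };
    size_t const k = 3;  // plays the role of beam_width in the patch above
    std::vector<item> top_k;
    for (float p : {0.5f, 0.1f, 0.9f, 0.4f, 0.7f, 0.2f}) {
        if (top_k.size() < k) {
            // Stage 1: gather elements until the vector is full, then heapify once.
            top_k.push_back({p});
            if (top_k.size() == k) {
                std::make_heap(top_k.begin(), top_k.end(), comp);
            }
        } else if (top_k.front().p < p) {
            // Stage 2: pop the least element to the back(), overwrite it, re-push.
            std::pop_heap(top_k.begin(), top_k.end(), comp);
            top_k.back() = {p};
            std::push_heap(top_k.begin(), top_k.end(), comp);
        }
    }
    for (item const & it : top_k) {
        std::printf("%.1f\n", it.p);  // the 3 largest: 0.5, 0.7, 0.9 (in heap order)
    }
}
```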
@@ -3006,7 +3006,7 @@ struct beam_search {
         if (b.eos()) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < static_cast<size_t>(beam_width)) {
-                next_beams.push_back(b);
+                next_beams.push_back(std::move(b));
                 if (next_beams.size() == static_cast<size_t>(beam_width)) {
                     std::make_heap(next_beams.begin(), next_beams.end(), comp);
                 }
@@ -3020,7 +3020,7 @@ struct beam_search {
             if (!b.tokens.empty()) {
                 llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
                 if (!common_prefix_evaluated && common_prefix_length) {
-                    b.shift_tokens(response, common_prefix_length);
+                    b.shift_tokens(common_prefix, common_prefix_length);
                     n_past += common_prefix_length;
                     common_prefix_evaluated = true;
                 }
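(Annotation: `shift_tokens` is defined elsewhere in this patch. Judging from its call sites, it moves the first `n` tokens of a beam into the destination vector, so the prefix shared by all beams is evaluated once and credited to `n_past` instead of being re-evaluated per beam. A plausible shape, as an illustration only:)

```cpp
// Illustration only -- the real definition lives elsewhere in the patch.
// Appends this beam's first n tokens to dest, then drops them from the beam.
void shift_tokens(std::vector<llama_token> & dest, size_t const n) {
    dest.insert(dest.end(), tokens.begin(), tokens.begin() + n);
    tokens.erase(tokens.begin(), tokens.begin() + n);
}
```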
@@ -3059,12 +3059,12 @@ struct beam_search {
     }

     // Loop:
-    // * while i < n_predict
-    // * until all of the beams have nreached end-of-sentence
+    // * while i < n_predict, OR
+    // * until all of the beams have reached end-of-sentence, OR
     // * until the highest probability beam is at end-of-sentence
     //   (since all other beam probabilities can only decrease)
-    void loop(std::function<void(std::vector<beam>&)> const callback) {
-        beams.push_back({{}, 1.0f});
+    void loop(llama_beam_search_callback_fn_t const callback, void * const callback_state) {
+        beams.push_back({{}, 1.0f});  // Start with one empty beam w/ probability = 1.0.
         auto const eos = [](beam const & beam) { return beam.eos(); };
         for (int i = 0; i < n_predict && !std::all_of(beams.begin(), beams.end(), eos) && !eos(top_beam()); ++i) {
             common_prefix_evaluated = false;
@@ -3075,10 +3075,14 @@ struct beam_search {
             beams.swap(next_beams);
             renormalize_beam_probabilities(beams);
             std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
-            callback(beams);
+            llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
+            callback(callback_state, common_beam_prefix);
+            common_prefix.clear();
         }
         beam& top_b = top_beam();
-        top_b.shift_tokens(response, top_b.tokens.size());
+        top_b.shift_tokens(common_prefix, top_b.tokens.size());
+        llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
+        callback(callback_state, common_beam_prefix);
     }

     // As beams grow, the cumulative probabilities decrease.
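(Annotation: a caller-side callback can simply append each reported prefix to its own buffer. A minimal sketch; `beam_search_state` and `on_common_prefix` are hypothetical names, and the `tokens`/`n_tokens` member names of `llama_tokens_view` are assumed from the `{data(), size()}` aggregate initializer above.)

```cpp
#include <vector>

// Hypothetical caller state, threaded through the void * callback_state parameter.
struct beam_search_state {
    std::vector<llama_token> generated;  // accumulated common prefixes
};

// Shape assumed to match llama_beam_search_callback_fn_t as used in this patch.
void on_common_prefix(void * state, llama_tokens_view const view) {
    auto & s = *static_cast<beam_search_state *>(state);
    // Member names tokens/n_tokens are an assumption (see lead-in).
    s.generated.insert(s.generated.end(), view.tokens, view.tokens + view.n_tokens);
}
```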
@@ -3096,38 +3100,20 @@ struct beam_search {
     }
 };

-// Not thread-safe.
-const char * llama_beam_search(llama_context * ctx, int beam_width,
-                               int n_past, int const n_predict, int const n_threads) {
-    static std::string beam_search_response;
+void llama_beam_search(llama_context * ctx,
+                       llama_beam_search_callback_fn_t callback, void * callback_state,
+                       int beam_width, int n_past, int const n_predict, int const n_threads) {
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();

     beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);

-    beam_search.loop([&](std::vector<beam>& beams) {
-#if 1 // DEBUG: print current beams for this iteration
-        std::cout << "\n\nCurrent beams:\n";
-        for (size_t j = 0; j < beams.size(); ++j) {
-            std::cout << "beams[" << j << "]: ";
-            out_beam(std::cout, ctx, beams[j]);
-            std::cout << std::endl;
-        }
-#else
-        std::cout << '.' << std::flush;  // Show progress
-#endif
-    });
-
-    // Save beam sentence to beam_search_response. Is there a better way?
-    std::ostringstream oss;
-    for (llama_token const token : beam_search.response) {
-        oss << llama_token_to_str(ctx, token);
-    }
-    beam_search_response = oss.str();
+    // callback(callback_state, common_beam_prefix) is called on each loop iteration, and once
+    // more when a stop condition is met, with the remaining tokens of the highest-probability beam.
+    beam_search.loop(callback, callback_state);

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     ctx->n_sample++;
-    return beam_search_response.c_str();
 }
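(Annotation: putting the new API together, a call site might look like this. `beam_search_state` and `on_common_prefix` are the hypothetical helpers sketched above; the beam width and token counts are arbitrary illustrative values.)

```cpp
beam_search_state state;
// ctx, n_past, and n_threads as prepared by the usual llama.cpp setup code.
llama_beam_search(ctx, on_common_prefix, &state,
                  /*beam_width=*/2, n_past, /*n_predict=*/32, /*n_threads=*/4);
// state.generated now holds the tokens of the most probable beam.
```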
//