@@ -2898,25 +2898,24 @@ struct beam {
2898
2898
float p; // Cumulative beam probability (renormalized with each token)
2899
2899
// end-of-sentence
2900
2900
bool eos () const { return !tokens.empty () && tokens.back () == llama_token_eos (); }
2901
- // Shift off first n tokens to the end of dest.
2902
- void shift_tokens (std::vector<llama_token>& dest, int const n) {
2903
- dest.resize (dest.size () + n);
2904
- std::copy (tokens.begin (), tokens.begin () + n, dest.end () - n);
2905
- shift_tokens (n);
2906
- }
2907
2901
// Shift off first n tokens and discard them.
2908
- void shift_tokens (int const n) {
2902
+ void shift_tokens (size_t const n) {
2909
2903
std::copy (tokens.begin () + n, tokens.end (), tokens.begin ());
2910
2904
tokens.resize (tokens.size () - n);
2911
2905
}
2912
2906
};
2913
2907
2914
- void out_beam (std::ostream& os, llama_context* ctx, beam const & beam) {
2915
- os << " p(" << beam.p << " ) eos(" << std::boolalpha << beam.eos () << " ) tokens(" ;
2916
- for (llama_token const token_id : beam.tokens ) {
2917
- os << llama_token_to_str (ctx, token_id);
2908
+ // Used for debugging to print out beam tokens.
2909
+ struct ostream_beam {
2910
+ llama_context* ctx;
2911
+ beam& b;
2912
+ };
2913
+ std::ostream& operator <<(std::ostream& os, ostream_beam const & osb) {
2914
+ os << " p(" << osb.b .p << " ) eos(" << std::boolalpha << osb.b .eos () << " ) tokens(" ;
2915
+ for (llama_token const token_id : osb.b .tokens ) {
2916
+ os << llama_token_to_str (osb.ctx , token_id);
2918
2917
}
2919
- os << ' )' ;
2918
+ return os << ' )' ;
2920
2919
}
2921
2920
2922
2921
// A struct for calculating logit-related info.
@@ -2974,11 +2973,9 @@ struct beam_search {
2974
2973
std::vector<beam> next_beams;
2975
2974
2976
2975
// Re-calculated on each loop iteration
2977
- int common_prefix_length;
2978
- // true iff llama_eval() has been called with common prefix in current loop iteration.
2976
+ size_t common_prefix_length;
2977
+ // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2979
2978
bool common_prefix_evaluated;
2980
- // Save token prefix common to all beams. Cleared after each loop iteration.
2981
- std::vector<llama_token> common_prefix;
2982
2979
2983
2980
beam_search (llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2984
2981
: ctx(ctx)
@@ -2992,13 +2989,14 @@ struct beam_search {
2992
2989
2993
2990
// Find common_prefix_length based on beams.
2994
2991
// Requires beams is not empty.
2995
- int find_common_prefix_length () {
2996
- int common_prefix_length = int ( beams[0 ].tokens .size () );
2997
- for (int i=1 ; i<int ( beams.size () ) ; ++i) {
2998
- int const j_max = std::min (common_prefix_length, int ( beams[i].tokens .size () ));
2999
- for (int j=0 ; j<j_max ; ++j) {
2992
+ size_t find_common_prefix_length () {
2993
+ size_t common_prefix_length = beams[0 ].tokens .size ();
2994
+ for (size_t i=1 ; i<beams.size () ; ++i) {
2995
+ common_prefix_length = std::min (common_prefix_length, beams[i].tokens .size ());
2996
+ for (size_t j=0 ; j<common_prefix_length ; ++j) {
3000
2997
if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
3001
- return j;
2998
+ common_prefix_length = j;
2999
+ break ;
3002
3000
}
3003
3001
}
3004
3002
}
@@ -3035,7 +3033,7 @@ struct beam_search {
3035
3033
if (!b.tokens .empty ()) {
3036
3034
llama_eval (ctx, b.tokens .data (), b.tokens .size (), n_past, n_threads);
3037
3035
if (!common_prefix_evaluated && common_prefix_length) {
3038
- b.shift_tokens (common_prefix, common_prefix_length);
3036
+ b.shift_tokens (common_prefix_length);
3039
3037
n_past += common_prefix_length;
3040
3038
common_prefix_evaluated = true ;
3041
3039
}
@@ -3082,22 +3080,26 @@ struct beam_search {
3082
3080
beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3083
3081
auto const eos = [](beam const & beam) { return beam.eos (); };
3084
3082
for (int i=0 ; i<n_predict && !std::all_of (beams.begin (),beams.end (),eos) && !eos (top_beam ()) ; ++i) {
3085
- common_prefix_evaluated = false ;
3086
3083
common_prefix_length = find_common_prefix_length ();
3084
+ llama_tokens_view const common_prefix{beams[0 ].tokens .data (), common_prefix_length};
3085
+ callback (callback_state, common_prefix);
3086
+ common_prefix_evaluated = false ;
3087
3087
for (beam& beam : beams) {
3088
3088
fill_next_beams_by_top_probabilities (beam);
3089
3089
}
3090
3090
beams.swap (next_beams);
3091
3091
renormalize_beam_probabilities (beams);
3092
3092
std::for_each (next_beams.begin (), next_beams.end (), [](beam& beam) { beam.p = 0 .0f ; });
3093
- llama_tokens_view const common_beam_prefix{common_prefix.data (), common_prefix.size ()};
3094
- callback (callback_state, common_beam_prefix);
3095
- common_prefix.clear ();
3093
+ #if 0 // DEBUG: print current beams for this iteration
3094
+ std::cout << "\n\nCurrent beams:\n";
3095
+ for (size_t j=0 ; j < beams.size() ; ++j) {
3096
+ std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
3097
+ }
3098
+ #endif
3096
3099
}
3097
3100
beam& top_b = top_beam ();
3098
- top_b.shift_tokens (common_prefix, top_b.tokens .size ());
3099
- llama_tokens_view const common_beam_prefix{common_prefix.data (), common_prefix.size ()};
3100
- callback (callback_state, common_beam_prefix);
3101
+ llama_tokens_view const top_beam_tokens{top_b.tokens .data (), top_b.tokens .size ()};
3102
+ callback (callback_state, top_beam_tokens);
3101
3103
}
3102
3104
3103
3105
// As beams grow, the cumulative probabilities decrease.
0 commit comments