@@ -2883,25 +2883,24 @@ struct beam {
2883
2883
float p; // Cumulative beam probability (renormalized with each token)
2884
2884
// end-of-sentence
2885
2885
bool eos () const { return !tokens.empty () && tokens.back () == llama_token_eos (); }
2886
- // Shift off first n tokens to the end of dest.
2887
- void shift_tokens (std::vector<llama_token>& dest, int const n) {
2888
- dest.resize (dest.size () + n);
2889
- std::copy (tokens.begin (), tokens.begin () + n, dest.end () - n);
2890
- shift_tokens (n);
2891
- }
2892
2886
// Shift off first n tokens and discard them.
2893
- void shift_tokens (int const n) {
2887
+ void shift_tokens (size_t const n) {
2894
2888
std::copy (tokens.begin () + n, tokens.end (), tokens.begin ());
2895
2889
tokens.resize (tokens.size () - n);
2896
2890
}
2897
2891
};
2898
2892
2899
- void out_beam (std::ostream& os, llama_context* ctx, beam const & beam) {
2900
- os << " p(" << beam.p << " ) eos(" << std::boolalpha << beam.eos () << " ) tokens(" ;
2901
- for (llama_token const token_id : beam.tokens ) {
2902
- os << llama_token_to_str (ctx, token_id);
2893
+ // Used for debugging to print out beam tokens.
2894
+ struct ostream_beam {
2895
+ llama_context* ctx;
2896
+ beam& b;
2897
+ };
2898
+ std::ostream& operator <<(std::ostream& os, ostream_beam const & osb) {
2899
+ os << " p(" << osb.b .p << " ) eos(" << std::boolalpha << osb.b .eos () << " ) tokens(" ;
2900
+ for (llama_token const token_id : osb.b .tokens ) {
2901
+ os << llama_token_to_str (osb.ctx , token_id);
2903
2902
}
2904
- os << ' )' ;
2903
+ return os << ' )' ;
2905
2904
}
2906
2905
2907
2906
// A struct for calculating logit-related info.
@@ -2959,11 +2958,9 @@ struct beam_search {
2959
2958
std::vector<beam> next_beams;
2960
2959
2961
2960
// Re-calculated on each loop iteration
2962
- int common_prefix_length;
2963
- // true iff llama_eval() has been called with common prefix in current loop iteration.
2961
+ size_t common_prefix_length;
2962
+ // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2964
2963
bool common_prefix_evaluated;
2965
- // Save token prefix common to all beams. Cleared after each loop iteration.
2966
- std::vector<llama_token> common_prefix;
2967
2964
2968
2965
beam_search (llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2969
2966
: ctx(ctx)
@@ -2977,13 +2974,14 @@ struct beam_search {
2977
2974
2978
2975
// Find common_prefix_length based on beams.
2979
2976
// Requires beams is not empty.
2980
- int find_common_prefix_length () {
2981
- int common_prefix_length = int ( beams[0 ].tokens .size () );
2982
- for (int i=1 ; i<int ( beams.size () ) ; ++i) {
2983
- int const j_max = std::min (common_prefix_length, int ( beams[i].tokens .size () ));
2984
- for (int j=0 ; j<j_max ; ++j) {
2977
+ size_t find_common_prefix_length () {
2978
+ size_t common_prefix_length = beams[0 ].tokens .size ();
2979
+ for (size_t i=1 ; i<beams.size () ; ++i) {
2980
+ common_prefix_length = std::min (common_prefix_length, beams[i].tokens .size ());
2981
+ for (size_t j=0 ; j<common_prefix_length ; ++j) {
2985
2982
if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
2986
- return j;
2983
+ common_prefix_length = j;
2984
+ break ;
2987
2985
}
2988
2986
}
2989
2987
}
@@ -3020,7 +3018,7 @@ struct beam_search {
3020
3018
if (!b.tokens .empty ()) {
3021
3019
llama_eval (ctx, b.tokens .data (), b.tokens .size (), n_past, n_threads);
3022
3020
if (!common_prefix_evaluated && common_prefix_length) {
3023
- b.shift_tokens (common_prefix, common_prefix_length);
3021
+ b.shift_tokens (common_prefix_length);
3024
3022
n_past += common_prefix_length;
3025
3023
common_prefix_evaluated = true ;
3026
3024
}
@@ -3067,22 +3065,26 @@ struct beam_search {
3067
3065
beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3068
3066
auto const eos = [](beam const & beam) { return beam.eos (); };
3069
3067
for (int i=0 ; i<n_predict && !std::all_of (beams.begin (),beams.end (),eos) && !eos (top_beam ()) ; ++i) {
3070
- common_prefix_evaluated = false ;
3071
3068
common_prefix_length = find_common_prefix_length ();
3069
+ llama_tokens_view const common_prefix{beams[0 ].tokens .data (), common_prefix_length};
3070
+ callback (callback_state, common_prefix);
3071
+ common_prefix_evaluated = false ;
3072
3072
for (beam& beam : beams) {
3073
3073
fill_next_beams_by_top_probabilities (beam);
3074
3074
}
3075
3075
beams.swap (next_beams);
3076
3076
renormalize_beam_probabilities (beams);
3077
3077
std::for_each (next_beams.begin (), next_beams.end (), [](beam& beam) { beam.p = 0 .0f ; });
3078
- llama_tokens_view const common_beam_prefix{common_prefix.data (), common_prefix.size ()};
3079
- callback (callback_state, common_beam_prefix);
3080
- common_prefix.clear ();
3078
+ #if 0 // DEBUG: print current beams for this iteration
3079
+ std::cout << "\n\nCurrent beams:\n";
3080
+ for (size_t j=0 ; j < beams.size() ; ++j) {
3081
+ std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
3082
+ }
3083
+ #endif
3081
3084
}
3082
3085
beam& top_b = top_beam ();
3083
- top_b.shift_tokens (common_prefix, top_b.tokens .size ());
3084
- llama_tokens_view const common_beam_prefix{common_prefix.data (), common_prefix.size ()};
3085
- callback (callback_state, common_beam_prefix);
3086
+ llama_tokens_view const top_beam_tokens{top_b.tokens .data (), top_b.tokens .size ()};
3087
+ callback (callback_state, top_beam_tokens);
3086
3088
}
3087
3089
3088
3090
// As beams grow, the cumulative probabilities decrease.
0 commit comments