47
47
#include < algorithm>
48
48
#include < initializer_list>
49
49
#include < thread>
50
+ #include < tuple>
50
51
#include < atomic>
51
52
#include < mutex>
52
53
#include < sstream>
@@ -2893,13 +2894,17 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
2893
2894
struct llama_beam {
2894
2895
std::vector<llama_token> tokens;
2895
2896
float p; // Cumulative beam probability (renormalized relative to all beams)
2896
- // end-of-sentence
2897
- bool eos () const { return !tokens.empty () && tokens.back () == llama_token_eos (); }
2897
+ bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
2898
+ // Sort beams by probability. In case of ties, prefer beams at eos.
2899
+ bool operator <(llama_beam const & rhs) const {
2900
+ return std::make_tuple (p, eos) < std::make_tuple (rhs.p , rhs.eos );
2901
+ }
2898
2902
// Shift off first n tokens and discard them.
2899
2903
void shift_tokens (size_t const n) {
2900
2904
std::copy (tokens.begin () + n, tokens.end (), tokens.begin ());
2901
2905
tokens.resize (tokens.size () - n);
2902
2906
}
2907
+ llama_beam_view view () const { return {tokens.data (), tokens.size (), p, eos}; }
2903
2908
};
2904
2909
2905
2910
// A struct for calculating logit-related info.
@@ -2961,7 +2966,7 @@ struct beam_search {
2961
2966
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2962
2967
bool common_prefix_evaluated;
2963
2968
2964
- // Temporary memory used by llama_beams_state to pass back via callback .
2969
+ // Used to communicate to/from callback on beams state .
2965
2970
std::vector<llama_beam_view> beam_views;
2966
2971
2967
2972
beam_search (llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
@@ -2996,7 +3001,7 @@ struct beam_search {
2996
3001
// with the common token prefix, so shift it off this beam.
2997
3002
beam.shift_tokens (common_prefix_length);
2998
3003
}
2999
- if (beam.eos () ) {
3004
+ if (beam.eos ) {
3000
3005
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3001
3006
if (next_beams.size () < beam_width) {
3002
3007
next_beams.push_back (std::move (beam));
@@ -3071,7 +3076,7 @@ struct beam_search {
3071
3076
// Side effect: set common_prefix_length = find_common_prefix_length();
3072
3077
llama_beams_state get_beams_state (bool const last_call) {
3073
3078
for (size_t i=0 ; i<beams.size () ; ++i) {
3074
- beam_views[i] = llama_beam_view{ beams[i].tokens . data (), beams[i]. tokens . size (), beams[i]. p } ;
3079
+ beam_views[i] = beams[i].view () ;
3075
3080
}
3076
3081
common_prefix_length = find_common_prefix_length ();
3077
3082
return {beam_views.data (), beams.size (), common_prefix_length, last_call};
@@ -3080,28 +3085,24 @@ struct beam_search {
3080
3085
// Loop:
3081
3086
// * while i < n_predict, AND
3082
3087
// * any of the beams have not yet reached end-of-sentence, AND
3083
- // * the highest probability beams (plural in case of ties) are not at end-of-sentence
3088
+ // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
3084
3089
// (since all other beam probabilities can only decrease)
3085
3090
void loop (llama_beam_search_callback_fn_t const callback, void * const callback_state) {
3086
- beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3087
- auto const not_eos = [](llama_beam const & beam) { return !beam.eos () ; };
3091
+ beams.push_back ({{}, 1 .0f , false }); // Start with one empty beam w/ probability = 1.0 and !eos .
3092
+ auto const not_eos = [](llama_beam const & beam) { return !beam.eos ; };
3088
3093
for (int i=0 ; i<n_predict && std::any_of (beams.begin (),beams.end (),not_eos) &&
3089
- !beams[top_beam_index ()].eos () ; ++i) {
3090
- llama_beam_search_control const control = callback (callback_state, get_beams_state (false ));
3091
- if (control.collapse_to < beams.size ()) {
3092
- // Caller has manually selected a specific beam. Collapse beams into it.
3093
- collapse_beams (control.collapse_to );
3094
- }
3095
- if (control.stop ) {
3096
- break ;
3097
- }
3098
- common_prefix_evaluated = false ;
3094
+ !beams[top_beam_index ()].eos ; ++i) {
3095
+ callback (callback_state, get_beams_state (false ));
3096
+ update_beams_from_beam_views (); // Update values (p,eos) that callback may have changed.
3097
+ common_prefix_evaluated = false ; // Any common prefix has not yet been llama_eval()ed.
3098
+ // Zero-out next_beam probabilities to place them last in following min-heap.
3099
+ std::for_each (next_beams.begin (), next_beams.end (), [](llama_beam& beam) { beam.p = 0 .0f ; });
3099
3100
for (llama_beam& beam : beams) {
3100
3101
fill_next_beams_by_top_probabilities (beam);
3101
3102
}
3103
+ // next_beams become the beams of next/final iteration. Swap them to re-use memory.
3102
3104
beams.swap (next_beams);
3103
3105
renormalize_beam_probabilities (beams);
3104
- std::for_each (next_beams.begin (), next_beams.end (), [](llama_beam& beam) { beam.p = 0 .0f ; });
3105
3106
}
3106
3107
collapse_beams (top_beam_index ());
3107
3108
callback (callback_state, get_beams_state (true ));
@@ -3115,13 +3116,17 @@ struct beam_search {
3115
3116
std::for_each (beams.begin (), beams.end (), [=](llama_beam& beam) { beam.p *= inv_sum; });
3116
3117
}
3117
3118
3118
- // Return index of highest ranking beam by (probability,eos()).
3119
- // In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
3120
- // Assumes beams is non-empty.
3119
+ // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering.
3121
3120
size_t top_beam_index () {
3122
- auto const by_p_and_eos = [](llama_beam const & a, llama_beam const & b) {
3123
- return a.p < b.p || (a.p == b.p && a.eos () < b.eos ()); };
3124
- return std::max_element (beams.begin (), beams.end (), by_p_and_eos) - beams.begin ();
3121
+ return std::max_element (beams.begin (), beams.end ()) - beams.begin ();
3122
+ }
3123
+
3124
+ // Copy (p,eos) for each beam which may have been changed by the callback.
3125
+ void update_beams_from_beam_views () {
3126
+ for (size_t i=0 ; i<beams.size () ; ++i) {
3127
+ beams[i].p = beam_views[i].p ;
3128
+ beams[i].eos = beam_views[i].eos ;
3129
+ }
3125
3130
}
3126
3131
};
3127
3132
0 commit comments