@@ -2878,20 +2878,18 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 }
 
 struct beam {
-    llama_context* ctx;
     std::vector<llama_token> tokens;
     float p; // Cumulative beam probability (renormalized with each token)
     // end-of-sentence
     bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
 };
 
-std::ostream& operator<<(std::ostream& os, beam const & b) {
-    os << "ctx(" << static_cast<void *>(b.ctx) << ") p(" << b.p
-       << ") eos(" << std::boolalpha << b.eos() << ") tokens(";
-    for (auto const token_id : b.tokens) {
-        os << llama_token_to_str(b.ctx, token_id);
+void out_beam(std::ostream& os, llama_context* ctx, beam const & b) {
+    os << "p(" << b.p << ") eos(" << std::boolalpha << b.eos() << ") tokens(";
+    for (llama_token const token_id : b.tokens) {
+        os << llama_token_to_str(ctx, token_id);
     }
-    return os << ')';
+    os << ')';
 }
 
 // A struct for calculating logit-related info.
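Since beam no longer stores its llama_context, callers pass the context at print time instead. A minimal usage sketch, not part of this commit, assuming a valid llama_context* ctx is in scope:

    beam b{{/* token ids */}, 1.0f};
    out_beam(std::cout, ctx, b); // prints p(...), eos(...) and the decoded tokens
    std::cout << std::endl;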
@@ -2938,8 +2936,8 @@ struct logit_info {
     }
 };
 
-void fill_next_beams_by_top_probabilities(std::vector<beam>& next_beams, beam const & b,
-        int const beam_width, int const n_past, int const n_threads) {
+void fill_next_beams_by_top_probabilities(llama_context* ctx, std::vector<beam>& next_beams,
+        beam const & b, int const beam_width, int const n_past, int const n_threads) {
     auto const comp = [](beam const & a, beam const & b) { return a.p > b.p; };
     if (b.eos()) {
         // b is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
@@ -2956,9 +2954,9 @@ void fill_next_beams_by_top_probabilities(std::vector<beam>& next_beams, beam co
     } else {
         // b is not at end-of-sentence, so branch with next top_k tokens.
         if (!b.tokens.empty()) {
-            llama_eval(b.ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
+            llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
         }
-        logit_info li(b.ctx);
+        logit_info li(ctx);
         std::vector<llama_token_data> next_tokens = li.top_k(beam_width);
         int i=0;
         if (next_beams.size() < static_cast<size_t>(beam_width)) {
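comp above orders beams by descending p; combined with the standard heap algorithms it behaves as a min-heap whose front is the least probable kept beam. A self-contained sketch of that top-k pattern (illustrative only; push_top_k is a hypothetical helper, and whether the surrounding code uses the heap algorithms exactly this way is not visible in this hunk):

    #include <algorithm>
    #include <vector>

    void push_top_k(std::vector<beam>& heap, beam cand, size_t k) {
        auto const comp = [](beam const & a, beam const & b) { return a.p > b.p; }; // min-heap by p
        if (heap.size() < k) {
            heap.push_back(std::move(cand));
            std::push_heap(heap.begin(), heap.end(), comp);
        } else if (cand.p > heap.front().p) {               // beats the weakest kept beam
            std::pop_heap(heap.begin(), heap.end(), comp);  // weakest moves to the back
            heap.back() = std::move(cand);                  // replace it
            std::push_heap(heap.begin(), heap.end(), comp); // restore heap order
        }
    }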
@@ -3001,15 +2999,14 @@ beam const& top_beam(std::vector<beam> const& beams) {
 // fill_next_beams_by_top_probabilities() by randomly selecting from all next_beams.
 // Not thread-safe.
 const char* llama_beam_search(llama_context * ctx, int const beam_width,
-        int const n_past, int const n_predict, int const n_threads) {
+        int n_past, int const n_predict, int const n_threads) {
     static std::string beam_search_response;
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
     std::vector<beam> beams;
     beams.reserve(beam_width);
-    beams.push_back({ctx, {}, 1.0});
-    // Init next_beams with unique next token_id each.
+    beams.push_back({{}, 1.0});
     std::vector<beam> next_beams;
     next_beams.reserve(beam_width);
     // Loop while there are any beams that have not yet reached end-of-sentence.
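renormalize_beam_probabilities(), called each iteration of the loop below, is unchanged by this commit and so absent from the diff. A plausible sketch, assuming it simply rescales the cumulative probabilities to sum to 1:

    // Assumed implementation; the real body is outside this diff.
    void renormalize_beam_probabilities(std::vector<beam>& beams) {
        float sum = 0.0f;
        for (beam const & b : beams) { sum += b.p; }
        for (beam & b : beams) { b.p /= sum; } // probabilities now sum to 1
    }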
@@ -3019,35 +3016,28 @@ const char* llama_beam_search(llama_context * ctx, int const beam_width,
     for (int i=0; i<n_predict && !eos(top_beam(beams)) &&
                   !std::all_of(beams.begin(), beams.end(), eos); ++i) {
         for (beam& b : beams) {
-            fill_next_beams_by_top_probabilities(next_beams, b, beam_width, n_past, n_threads);
+            fill_next_beams_by_top_probabilities(ctx, next_beams, b, beam_width, n_past, n_threads);
         }
         beams.swap(next_beams);
         next_beams.clear();
         renormalize_beam_probabilities(beams);
 #if 1 // DEBUG: print current beams for this iteration
         std::cout << "\n\nCurrent beams:\n";
         for (size_t j=0; j < beams.size(); ++j) {
-            std::cout << "beams[" << j << "]: " << beams[j] << std::endl;
+            std::cout << "beams[" << j << "]: ";
+            out_beam(std::cout, ctx, beams[j]);
+            std::cout << std::endl;
         }
 #else
         std::cout << '.' << std::flush; // Show progress
 #endif
     }
-#if 1 // DEBUG: print final beam results
-    for (size_t i=0; i<beams.size(); ++i) {
-        std::cout << "\nbeams[" << i << "] with p(" << beams[i].p << "): ";
-        for (llama_token const token : beams[i].tokens) {
-            std::cout << llama_token_to_str(beams[i].ctx, token);
-        }
-        std::cout << std::endl;
-    }
-#endif
 
     beam const& top_b = top_beam(beams);
     // Save beam sentence to beam_search_response. Is there a better way?
     std::ostringstream oss;
     for (llama_token const token : top_b.tokens) {
-        oss << llama_token_to_str(top_b.ctx, token);
+        oss << llama_token_to_str(ctx, token);
     }
     beam_search_response = oss.str();
 
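With ctx threaded through explicitly, callers touch the context only at the top-level entry point. A hedged usage sketch (parameter values and the prior prompt evaluation are assumptions, not shown in this commit); because the result points into the static beam_search_response, it stays valid only until the next call:

    // Hypothetical caller; assumes the prompt was already evaluated into ctx,
    // leaving n_past tokens in the KV cache.
    int const beam_width = 2;  // beams kept per step (illustrative)
    int const n_past     = 32; // tokens already evaluated
    int const n_predict  = 64; // maximum new tokens to generate
    int const n_threads  = 4;
    const char * response = llama_beam_search(ctx, beam_width, n_past, n_predict, n_threads);
    printf("%s\n", response); // invalidated by the next llama_beam_search() call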