Skip to content

Commit 117088a

Browse files
committed
Drop unnecessary use of std::vector<llama_token> common_prefix.
1 parent 29ec1a0 commit 117088a

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

llama.cpp

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2883,25 +2883,24 @@ struct beam {
28832883
float p; // Cumulative beam probability (renormalized with each token)
28842884
// end-of-sentence
28852885
bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
2886-
// Shift off first n tokens to the end of dest.
2887-
void shift_tokens(std::vector<llama_token>& dest, int const n) {
2888-
dest.resize(dest.size() + n);
2889-
std::copy(tokens.begin(), tokens.begin() + n, dest.end() - n);
2890-
shift_tokens(n);
2891-
}
28922886
// Shift off first n tokens and discard them.
2893-
void shift_tokens(int const n) {
2887+
void shift_tokens(size_t const n) {
28942888
std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
28952889
tokens.resize(tokens.size() - n);
28962890
}
28972891
};
28982892

2899-
void out_beam(std::ostream& os, llama_context* ctx, beam const& beam) {
2900-
os << "p(" << beam.p << ") eos(" << std::boolalpha << beam.eos() << ") tokens(";
2901-
for (llama_token const token_id : beam.tokens) {
2902-
os << llama_token_to_str(ctx, token_id);
2893+
// Used for debugging to print out beam tokens.
2894+
struct ostream_beam {
2895+
llama_context* ctx;
2896+
beam& b;
2897+
};
2898+
std::ostream& operator<<(std::ostream& os, ostream_beam const& osb) {
2899+
os << "p(" << osb.b.p << ") eos(" << std::boolalpha << osb.b.eos() << ") tokens(";
2900+
for (llama_token const token_id : osb.b.tokens) {
2901+
os << llama_token_to_str(osb.ctx, token_id);
29032902
}
2904-
os << ')';
2903+
return os << ')';
29052904
}
29062905

29072906
// A struct for calculating logit-related info.
@@ -2959,11 +2958,9 @@ struct beam_search {
29592958
std::vector<beam> next_beams;
29602959

29612960
// Re-calculated on each loop iteration
2962-
int common_prefix_length;
2963-
// true iff llama_eval() has been called with common prefix in current loop iteration.
2961+
size_t common_prefix_length;
2962+
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29642963
bool common_prefix_evaluated;
2965-
// Save token prefix common to all beams. Cleared after each loop iteration.
2966-
std::vector<llama_token> common_prefix;
29672964

29682965
beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
29692966
: ctx(ctx)
@@ -2977,13 +2974,14 @@ struct beam_search {
29772974

29782975
// Find common_prefix_length based on beams.
29792976
// Requires beams is not empty.
2980-
int find_common_prefix_length() {
2981-
int common_prefix_length = int(beams[0].tokens.size());
2982-
for (int i=1 ; i<int(beams.size()) ; ++i) {
2983-
int const j_max = std::min(common_prefix_length, int(beams[i].tokens.size()));
2984-
for (int j=0 ; j<j_max ; ++j) {
2977+
size_t find_common_prefix_length() {
2978+
size_t common_prefix_length = beams[0].tokens.size();
2979+
for (size_t i=1 ; i<beams.size() ; ++i) {
2980+
common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
2981+
for (size_t j=0 ; j<common_prefix_length ; ++j) {
29852982
if (beams[0].tokens[j] != beams[i].tokens[j]) {
2986-
return j;
2983+
common_prefix_length = j;
2984+
break;
29872985
}
29882986
}
29892987
}
@@ -3020,7 +3018,7 @@ struct beam_search {
30203018
if (!b.tokens.empty()) {
30213019
llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
30223020
if (!common_prefix_evaluated && common_prefix_length) {
3023-
b.shift_tokens(common_prefix, common_prefix_length);
3021+
b.shift_tokens(common_prefix_length);
30243022
n_past += common_prefix_length;
30253023
common_prefix_evaluated = true;
30263024
}
@@ -3067,22 +3065,26 @@ struct beam_search {
30673065
beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
30683066
auto const eos = [](beam const& beam) { return beam.eos(); };
30693067
for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
3070-
common_prefix_evaluated = false;
30713068
common_prefix_length = find_common_prefix_length();
3069+
llama_tokens_view const common_prefix{beams[0].tokens.data(), common_prefix_length};
3070+
callback(callback_state, common_prefix);
3071+
common_prefix_evaluated = false;
30723072
for (beam& beam : beams) {
30733073
fill_next_beams_by_top_probabilities(beam);
30743074
}
30753075
beams.swap(next_beams);
30763076
renormalize_beam_probabilities(beams);
30773077
std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
3078-
llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
3079-
callback(callback_state, common_beam_prefix);
3080-
common_prefix.clear();
3078+
#if 0 // DEBUG: print current beams for this iteration
3079+
std::cout << "\n\nCurrent beams:\n";
3080+
for (size_t j=0 ; j < beams.size() ; ++j) {
3081+
std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
3082+
}
3083+
#endif
30813084
}
30823085
beam& top_b = top_beam();
3083-
top_b.shift_tokens(common_prefix, top_b.tokens.size());
3084-
llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
3085-
callback(callback_state, common_beam_prefix);
3086+
llama_tokens_view const top_beam_tokens{top_b.tokens.data(), top_b.tokens.size()};
3087+
callback(callback_state, top_beam_tokens);
30863088
}
30873089

30883090
// As beams grow, the cumulative probabilities decrease.

0 commit comments

Comments
 (0)