
Commit ef7850d

Drop unnecessary use of std::vector<llama_token> common_prefix.
1 parent: ce55910


llama.cpp

Lines changed: 32 additions & 30 deletions
@@ -2898,25 +2898,24 @@ struct beam {
     float p; // Cumulative beam probability (renormalized with each token)
     // end-of-sentence
     bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
-    // Shift off first n tokens to the end of dest.
-    void shift_tokens(std::vector<llama_token>& dest, int const n) {
-        dest.resize(dest.size() + n);
-        std::copy(tokens.begin(), tokens.begin() + n, dest.end() - n);
-        shift_tokens(n);
-    }
     // Shift off first n tokens and discard them.
-    void shift_tokens(int const n) {
+    void shift_tokens(size_t const n) {
         std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
         tokens.resize(tokens.size() - n);
     }
 };
 
-void out_beam(std::ostream& os, llama_context* ctx, beam const& beam) {
-    os << "p(" << beam.p << ") eos(" << std::boolalpha << beam.eos() << ") tokens(";
-    for (llama_token const token_id : beam.tokens) {
-        os << llama_token_to_str(ctx, token_id);
+// Used for debugging to print out beam tokens.
+struct ostream_beam {
+    llama_context* ctx;
+    beam& b;
+};
+std::ostream& operator<<(std::ostream& os, ostream_beam const& osb) {
+    os << "p(" << osb.b.p << ") eos(" << std::boolalpha << osb.b.eos() << ") tokens(";
+    for (llama_token const token_id : osb.b.tokens) {
+        os << llama_token_to_str(osb.ctx, token_id);
     }
-    os << ')';
+    return os << ')';
 }
 
 // A struct for calculating logit-related info.
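Aside: the new ostream_beam wrapper above is just a small struct plus an operator<< overload, so a beam can be streamed directly when debugging. The following is a minimal standalone sketch of the same pattern; fake_beam, token_to_str, and ostream_fake_beam are illustrative stand-ins, since a real llama_context and llama_token_to_str are not available outside llama.cpp.

// Minimal sketch of the ostream-wrapper pattern used above.
// fake_beam and token_to_str are stand-ins, not part of llama.cpp.
#include <iostream>
#include <string>
#include <vector>

struct fake_beam {
    std::vector<int> tokens;
    float p;
    bool eos() const { return !tokens.empty() && tokens.back() == 2; } // pretend 2 == EOS
};

// Stand-in for llama_token_to_str(ctx, token_id).
static std::string token_to_str(int token_id) { return "<" + std::to_string(token_id) + ">"; }

// Same idea as ostream_beam: bundle what operator<< needs into one small struct.
struct ostream_fake_beam {
    fake_beam const & b;
};

std::ostream & operator<<(std::ostream & os, ostream_fake_beam const & osb) {
    os << "p(" << osb.b.p << ") eos(" << std::boolalpha << osb.b.eos() << ") tokens(";
    for (int const token_id : osb.b.tokens) {
        os << token_to_str(token_id);
    }
    return os << ')';
}

int main() {
    fake_beam b{{5, 7, 2}, 0.25f};
    std::cout << ostream_fake_beam{b} << std::endl; // p(0.25) eos(true) tokens(<5><7><2>)
}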
@@ -2974,11 +2973,9 @@ struct beam_search {
     std::vector<beam> next_beams;
 
     // Re-calculated on each loop iteration
-    int common_prefix_length;
-    // true iff llama_eval() has been called with common prefix in current loop iteration.
+    size_t common_prefix_length;
+    // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
     bool common_prefix_evaluated;
-    // Save token prefix common to all beams. Cleared after each loop iteration.
-    std::vector<llama_token> common_prefix;
 
     beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
       : ctx(ctx)
@@ -2992,13 +2989,14 @@ struct beam_search {
 
     // Find common_prefix_length based on beams.
     // Requires beams is not empty.
-    int find_common_prefix_length() {
-        int common_prefix_length = int(beams[0].tokens.size());
-        for (int i=1 ; i<int(beams.size()) ; ++i) {
-            int const j_max = std::min(common_prefix_length, int(beams[i].tokens.size()));
-            for (int j=0 ; j<j_max ; ++j) {
+    size_t find_common_prefix_length() {
+        size_t common_prefix_length = beams[0].tokens.size();
+        for (size_t i=1 ; i<beams.size() ; ++i) {
+            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
+            for (size_t j=0 ; j<common_prefix_length ; ++j) {
                 if (beams[0].tokens[j] != beams[i].tokens[j]) {
-                    return j;
+                    common_prefix_length = j;
+                    break;
                 }
             }
         }
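Aside: the rewritten prefix search is easy to check in isolation. Below is a hedged standalone sketch of the same logic over plain int tokens; find_common_prefix_length here is a free function over a vector of token vectors, not the member function from the diff.

// Standalone sketch of the common-prefix computation above, using
// std::vector<int> in place of beams of llama_token. Illustrative only.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Length of the token prefix shared by all sequences. Requires beams is not empty.
static size_t find_common_prefix_length(std::vector<std::vector<int>> const & beams) {
    size_t common_prefix_length = beams[0].size();
    for (size_t i = 1; i < beams.size(); ++i) {
        common_prefix_length = std::min(common_prefix_length, beams[i].size());
        for (size_t j = 0; j < common_prefix_length; ++j) {
            if (beams[0][j] != beams[i][j]) {
                common_prefix_length = j;
                break;
            }
        }
    }
    return common_prefix_length;
}

int main() {
    assert(find_common_prefix_length({{1, 2, 3}, {1, 2, 4, 5}, {1, 2}}) == 2);
    assert(find_common_prefix_length({{9, 9}, {8}}) == 0);
    assert(find_common_prefix_length({{7, 7, 7}}) == 3);
}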
@@ -3035,7 +3033,7 @@ struct beam_search {
         if (!b.tokens.empty()) {
             llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
             if (!common_prefix_evaluated && common_prefix_length) {
-                b.shift_tokens(common_prefix, common_prefix_length);
+                b.shift_tokens(common_prefix_length);
                 n_past += common_prefix_length;
                 common_prefix_evaluated = true;
             }
@@ -3082,22 +3080,26 @@ struct beam_search {
         beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
         auto const eos = [](beam const& beam) { return beam.eos(); };
         for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
-            common_prefix_evaluated = false;
             common_prefix_length = find_common_prefix_length();
+            llama_tokens_view const common_prefix{beams[0].tokens.data(), common_prefix_length};
+            callback(callback_state, common_prefix);
+            common_prefix_evaluated = false;
             for (beam& beam : beams) {
                 fill_next_beams_by_top_probabilities(beam);
             }
             beams.swap(next_beams);
             renormalize_beam_probabilities(beams);
             std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
+#if 0 // DEBUG: print current beams for this iteration
+            std::cout << "\n\nCurrent beams:\n";
+            for (size_t j=0 ; j < beams.size() ; ++j) {
+                std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl;
+            }
+#endif
         }
         beam& top_b = top_beam();
-        top_b.shift_tokens(common_prefix, top_b.tokens.size());
-        llama_tokens_view const common_beam_prefix{common_prefix.data(), common_prefix.size()};
-        callback(callback_state, common_beam_prefix);
+        llama_tokens_view const top_beam_tokens{top_b.tokens.data(), top_b.tokens.size()};
+        callback(callback_state, top_beam_tokens);
     }
 
     // As beams grow, the cumulative probabilities decrease.
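Aside: the net effect of the commit is that the shared prefix is no longer copied into a separate std::vector before being handed to the callback; a {data pointer, length} view into the surviving beam's tokens is passed instead. Below is a small sketch of that idea, where tokens_view is a hypothetical stand-in for llama_tokens_view (whose field names are not shown in this diff) and on_common_prefix is an illustrative caller-side handler.

// Sketch of reporting a shared prefix through a non-owning view instead of a copy.
// tokens_view is a hypothetical stand-in for llama_tokens_view; its field names
// are assumptions, not taken from llama.cpp.
#include <cstddef>
#include <iostream>
#include <vector>

using token = int; // stand-in for llama_token

struct tokens_view {
    token const * data; // assumed field
    size_t size;        // assumed field
};

// Caller-side handler, analogous to the beam-search callback: it only reads the view.
static void on_common_prefix(tokens_view v) {
    for (size_t i = 0; i < v.size; ++i) {
        std::cout << v.data[i] << ' ';
    }
    std::cout << '\n';
}

int main() {
    std::vector<token> best_beam_tokens{3, 1, 4, 1, 5};
    size_t const common_prefix_length = 3;

    // Before the commit: tokens were first copied into a separate std::vector.
    // After: a view into the existing storage is enough, so no copy is made.
    on_common_prefix(tokens_view{best_beam_tokens.data(), common_prefix_length});
}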
