Skip to content

Commit 5fa1ea2

Browse files
committed
Change eos to eob in llama_beam and llama_beam_view structs.
1 parent b619cfc commit 5fa1ea2

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct ostream_beam_view {
3333
llama_beam_view beam_view;
3434
};
3535
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
36-
os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
36+
os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
3737
for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
3838
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
3939
}
@@ -46,7 +46,9 @@ struct beam_search_callback_data {
4646
std::vector<llama_token> response;
4747
};
4848

49-
bool is_at_eos(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
49+
// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
50+
// For example, eob can be flagged due to maximum token length, stop words, etc.
51+
bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
5052
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
5153
}
5254

@@ -61,8 +63,8 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_stat
6163
// Mark beams as EOS as needed.
6264
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
6365
llama_beam_view& beam_view = beams_state.beam_views[i];
64-
if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) {
65-
beam_view.eos = true;
66+
if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
67+
beam_view.eob = true;
6668
}
6769
}
6870
printf(","); // Show progress

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ static void log_server_request(const Request &req, const Response &res)
12091209
});
12101210
}
12111211

1212-
bool is_at_eos(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
1212+
bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
12131213
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
12141214
}
12151215

@@ -1223,9 +1223,9 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
12231223
auto & llama = *static_cast<llama_server_context*>(callback_data);
12241224
// Mark beams as EOS as needed.
12251225
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
1226-
llama_beam_view & beam_view = beams_state.beam_views[i];
1227-
if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
1228-
beam_view.eos = true;
1226+
llama_beam_view& beam_view = beams_state.beam_views[i];
1227+
if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
1228+
beam_view.eob = true;
12291229
}
12301230
}
12311231
printf(","); // Show progress

llama.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4333,10 +4333,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
43334333
struct llama_beam {
43344334
std::vector<llama_token> tokens;
43354335
float p; // Cumulative beam probability (renormalized relative to all beams)
4336-
bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
4337-
// Sort beams by probability. In case of ties, prefer beams at eos.
4336+
bool eob; // Initialize end-of-beam to false. Callback sets this to true.
4337+
// Sort beams by probability. In case of ties, prefer beams at eob.
43384338
bool operator<(const llama_beam & rhs) const {
4339-
return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
4339+
return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
43404340
}
43414341
// Shift off first n tokens and discard them.
43424342
void shift_tokens(const size_t n) {
@@ -4345,7 +4345,7 @@ struct llama_beam {
43454345
tokens.resize(tokens.size() - n);
43464346
}
43474347
}
4348-
llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
4348+
llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
43494349
};
43504350

43514351
// A struct for calculating logit-related info.
@@ -4435,7 +4435,7 @@ struct llama_beam_search_data {
44354435
void fill_next_beams_by_top_probabilities(llama_beam & beam) {
44364436
// Min-heaps use a greater-than comparator.
44374437
const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
4438-
if (beam.eos) {
4438+
if (beam.eob) {
44394439
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
44404440
if (next_beams.size() < n_beams) {
44414441
next_beams.push_back(std::move(beam));
@@ -4513,16 +4513,16 @@ struct llama_beam_search_data {
45134513

45144514
// Loop:
45154515
// * while i < n_predict, AND
4516-
// * any of the beams have not yet reached end-of-sentence, AND
4516+
// * any of the beams have not yet reached end-of-beam (eob), AND
45174517
// * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
45184518
// (since all other beam probabilities can only decrease)
45194519
void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
4520-
beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eos.
4521-
const auto not_eos = [](const llama_beam & beam) { return !beam.eos; };
4522-
for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
4523-
!beams[top_beam_index()].eos ; ++i) {
4520+
beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob.
4521+
const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
4522+
for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) &&
4523+
!beams[top_beam_index()].eob ; ++i) {
45244524
callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
4525-
update_beams_from_beam_views(); // Update values (p,eos) that callback may have changed.
4525+
update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed.
45264526
if (common_prefix_length) {
45274527
llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
45284528
n_past += common_prefix_length;
@@ -4554,11 +4554,11 @@ struct llama_beam_search_data {
45544554
return std::max_element(beams.begin(), beams.end()) - beams.begin();
45554555
}
45564556

4557-
// Copy (p,eos) for each beam which may have been changed by the callback.
4557+
// Copy (p,eob) for each beam which may have been changed by the callback.
45584558
void update_beams_from_beam_views() {
45594559
for (size_t i = 0 ; i < beams.size() ; ++i) {
45604560
beams[i].p = beam_views[i].p;
4561-
beams[i].eos = beam_views[i].eos;
4561+
beams[i].eob = beam_views[i].eob;
45624562
}
45634563
}
45644564
};

llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ extern "C" {
473473
const llama_token * tokens;
474474
size_t n_tokens;
475475
float p; // Cumulative beam probability (renormalized relative to all beams)
476-
bool eos; // Callback should set this to true when a beam is at end-of-sentence.
476+
bool eob; // Callback should set this to true when a beam is at end-of-beam.
477477
};
478478

479479
// Passed to beam_search_callback function.

0 commit comments

Comments
 (0)