@@ -4333,10 +4333,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 struct llama_beam {
     std::vector<llama_token> tokens;
     float p;  // Cumulative beam probability (renormalized relative to all beams)
-    bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
-    // Sort beams by probability. In case of ties, prefer beams at eos.
+    bool eob; // Initialize end-of-beam to false. Callback sets this to true.
+    // Sort beams by probability. In case of ties, prefer beams at eob.
     bool operator<(const llama_beam & rhs) const {
-        return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
+        return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
     }
     // Shift off first n tokens and discard them.
     void shift_tokens(const size_t n) {
@@ -4345,7 +4345,7 @@ struct llama_beam {
             tokens.resize(tokens.size() - n);
         }
     }
-    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
+    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
 };
 
 // A struct for calculating logit-related info.
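Note on the operator< change above: switching from std::make_tuple to std::make_pair keeps the same lexicographic ordering since only two fields are compared. Beams order primarily by p, and when probabilities tie, a beam with eob == true compares greater, so std::max_element (used by top_beam_index() further down) prefers finished beams. A self-contained sketch of that ordering, using a simplified stand-in struct rather than the real llama_beam:

```cpp
// Simplified sketch (not llama.cpp code): only the two fields that
// participate in the ordering are modeled.
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

struct beam {
    float p;
    bool  eob;
    bool operator<(const beam & rhs) const {
        // Lexicographic: compare p first; on a tie, false < true, so eob wins.
        return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
    }
};

int main() {
    std::vector<beam> beams = {{0.4f, false}, {0.4f, true}, {0.2f, false}};
    // On the 0.4 probability tie, the finished (eob == true) beam compares greater.
    const auto top = std::max_element(beams.begin(), beams.end());
    assert(top->eob);
    return 0;
}
```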
@@ -4435,7 +4435,7 @@ struct llama_beam_search_data {
     void fill_next_beams_by_top_probabilities(llama_beam & beam) {
         // Min-heaps use a greater-than comparator.
         const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
-        if (beam.eos) {
+        if (beam.eob) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < n_beams) {
                 next_beams.push_back(std::move(beam));
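The comment "Min-heaps use a greater-than comparator" refers to keeping only the n_beams highest-probability candidates: with comp ordering by a.p > b.p, the front of the heap is the lowest probability kept so far, i.e. the one to evict. A minimal sketch of that top-k pattern using a simplified beam type; the heap-maintenance steps are illustrative and not a copy of fill_next_beams_by_top_probabilities():

```cpp
#include <algorithm>
#include <vector>

struct beam { float p; };

// Keep the k highest-probability beams in `heap`, treated as a min-heap on p.
void keep_top_k(std::vector<beam> & heap, const beam & candidate, size_t k) {
    // A greater-than comparator makes std::*_heap maintain a min-heap (lowest p at front()).
    const auto comp = [](const beam & a, const beam & b) { return a.p > b.p; };
    if (heap.size() < k) {
        heap.push_back(candidate);
        std::push_heap(heap.begin(), heap.end(), comp);
    } else if (comp(candidate, heap.front())) {          // candidate.p > lowest kept p
        std::pop_heap(heap.begin(), heap.end(), comp);   // move the least beam to back()
        heap.back() = candidate;                         // replace it with the better candidate
        std::push_heap(heap.begin(), heap.end(), comp);  // restore the heap property
    }
}
```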
@@ -4513,16 +4513,16 @@ struct llama_beam_search_data {
 
     // Loop:
     //  * while i < n_predict, AND
-    //  * any of the beams have not yet reached end-of-sentence, AND
+    //  * any of the beams have not yet reached end-of-beam (eob), AND
     //  * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
     //    (since all other beam probabilities can only decrease)
     void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
-        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
-        const auto not_eos = [](const llama_beam & beam) { return !beam.eos; };
-        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
-                       !beams[top_beam_index()].eos ; ++i) {
+        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eob.
+        const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
+        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) &&
+                       !beams[top_beam_index()].eob ; ++i) {
             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
-            update_beams_from_beam_views();   // Update values (p,eos) that callback may have changed.
+            update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed.
             if (common_prefix_length) {
                 llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
                 n_past += common_prefix_length;
@@ -4554,11 +4554,11 @@ struct llama_beam_search_data {
         return std::max_element(beams.begin(), beams.end()) - beams.begin();
     }
 
-    // Copy (p,eos) for each beam which may have been changed by the callback.
+    // Copy (p,eob) for each beam which may have been changed by the callback.
     void update_beams_from_beam_views() {
         for (size_t i = 0 ; i < beams.size() ; ++i) {
             beams[i].p = beam_views[i].p;
-            beams[i].eos = beam_views[i].eos;
+            beams[i].eob = beam_views[i].eob;
         }
     }
 };
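The loop() above drives generation until every beam has reached end-of-beam (or the top beam has); per the comments, it is the callback's job to decide when that happens ("Callback sets this to true"), and update_beams_from_beam_views() then copies the changed (p, eob) values back into the beams. Below is a minimal sketch of such a callback. The llama_beams_state / llama_beam_view field names (beam_views, n_beams, common_prefix_length, tokens, n_tokens, p, eob) and the callback signature are assumptions inferred from view() and the callback(callback_data, get_beams_state(false)) call in this diff, not part of the change itself; eos_token_id is assumed to be looked up elsewhere (e.g. from the model's vocabulary).

```cpp
// Hypothetical user-side callback: marks a beam as ended (eob) once its last
// token is the end-of-sentence token, and collects the prefix that all beams
// agree on. Field and type names here are assumptions, not part of this diff.
struct my_callback_data {
    llama_token eos_token_id;            // assumed to be obtained from the vocabulary
    std::vector<llama_token> response;   // tokens accepted so far
};

static void my_beam_search_callback(void * data, llama_beams_state state) {
    auto & cb = *static_cast<my_callback_data *>(data);
    for (size_t i = 0; i < state.n_beams; ++i) {
        llama_beam_view & bv = state.beam_views[i];
        // End this beam when its last token is end-of-sentence; loop() picks the
        // change up via update_beams_from_beam_views().
        if (!bv.eob && bv.n_tokens > 0 && bv.tokens[bv.n_tokens - 1] == cb.eos_token_id) {
            bv.eob = true;
        }
    }
    // Tokens shared by every beam are final and can be accepted into the response.
    if (const size_t n = state.common_prefix_length) {
        cb.response.insert(cb.response.end(),
                           state.beam_views[0].tokens,
                           state.beam_views[0].tokens + n);
    }
}
```

One motivation suggested by the eos → eob rename: the library itself no longer implies that a beam ends only at the end-of-sentence token, so a callback like the one sketched here could end a beam on any criterion (maximum length, stop strings, etc.).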