
Commit ac77583

Remove llama_context* ctx member from beam struct.

1 parent a99cc90

File tree

1 file changed: +16, -26 lines

llama.cpp

Lines changed: 16 additions & 26 deletions
@@ -2878,20 +2878,18 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 }
 
 struct beam {
-    llama_context* ctx;
     std::vector<llama_token> tokens;
     float p;  // Cumulative beam probability (renormalized with each token)
     // end-of-sentence
     bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
 };
 
-std::ostream& operator<<(std::ostream& os, beam const& b) {
-    os << "ctx(" << static_cast<void*>(b.ctx) << ") p(" << b.p
-       << ") eos(" << std::boolalpha << b.eos() << ") tokens(";
-    for (auto const token_id : b.tokens) {
-        os << llama_token_to_str(b.ctx, token_id);
+void out_beam(std::ostream& os, llama_context* ctx, beam const& b) {
+    os << "p(" << b.p << ") eos(" << std::boolalpha << b.eos() << ") tokens(";
+    for (llama_token const token_id : b.tokens) {
+        os << llama_token_to_str(ctx, token_id);
     }
-    return os << ')';
+    os << ')';
 }
 
 // A struct for calculating logit-related info.
@@ -2938,8 +2936,8 @@ struct logit_info {
     }
 };
 
-void fill_next_beams_by_top_probabilities(std::vector<beam>& next_beams, beam const& b,
-        int const beam_width, int const n_past, int const n_threads) {
+void fill_next_beams_by_top_probabilities(llama_context* ctx, std::vector<beam>& next_beams,
+        beam const& b, int const beam_width, int const n_past, int const n_threads) {
     auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
     if (b.eos()) {
         // b is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
@@ -2956,9 +2954,9 @@ void fill_next_beams_by_top_probabilities(std::vector<beam>& next_beams, beam co
     } else {
         // b is not at end-of-sentence, so branch with next top_k tokens.
         if (!b.tokens.empty()) {
-            llama_eval(b.ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
+            llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
         }
-        logit_info li(b.ctx);
+        logit_info li(ctx);
         std::vector<llama_token_data> next_tokens = li.top_k(beam_width);
         int i=0;
         if (next_beams.size() < static_cast<size_t>(beam_width)) {
@@ -3001,15 +2999,14 @@ beam const& top_beam(std::vector<beam> const& beams) {
 // fill_next_beams_by_top_probabilities() by randomly selecting from all next_beams.
 // Not thread-safe.
 const char* llama_beam_search(llama_context * ctx, int const beam_width,
-        int const n_past, int const n_predict, int const n_threads) {
+        int n_past, int const n_predict, int const n_threads) {
     static std::string beam_search_response;
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
     std::vector<beam> beams;
     beams.reserve(beam_width);
-    beams.push_back({ctx, {}, 1.0});
-    // Init next_beams with unique next token_id each.
+    beams.push_back({{}, 1.0});
     std::vector<beam> next_beams;
     next_beams.reserve(beam_width);
     // Loop while there are any beams that have not yet reached end-of-sentence.
@@ -3019,35 +3016,28 @@ const char* llama_beam_search(llama_context * ctx, int const beam_width,
     for (int i=0 ; i<n_predict && !eos(top_beam(beams)) &&
                    !std::all_of(beams.begin(), beams.end(), eos); ++i) {
         for (beam& b : beams) {
-            fill_next_beams_by_top_probabilities(next_beams, b, beam_width, n_past, n_threads);
+            fill_next_beams_by_top_probabilities(ctx, next_beams, b, beam_width, n_past, n_threads);
         }
         beams.swap(next_beams);
         next_beams.clear();
         renormalize_beam_probabilities(beams);
 #if 1 // DEBUG: print current beams for this iteration
         std::cout << "\n\nCurrent beams:\n";
         for (size_t j=0 ; j < beams.size() ; ++j) {
-            std::cout << "beams["<<j<<"]: " << beams[j] << std::endl;;
+            std::cout << "beams["<<j<<"]: ";
+            out_beam(std::cout, ctx, beams[j]);
+            std::cout << std::endl;
         }
 #else
         std::cout << '.' << std::flush; // Show progress
 #endif
     }
-#if 1 // DEBUG: print final beam results
-    for (size_t i=0 ; i<beams.size() ; ++i) {
-        std::cout << "\nbeams["<<i<<"] with p(" << beams[i].p << "): ";
-        for (llama_token const token : beams[i].tokens) {
-            std::cout << llama_token_to_str(beams[i].ctx, token);
-        }
-        std::cout << std::endl;
-    }
-#endif
 
     beam const& top_b = top_beam(beams);
     // Save beam sentence to beam_search_response. Is there a better way?
     std::ostringstream oss;
     for (llama_token const token : top_b.tokens) {
-        oss << llama_token_to_str(top_b.ctx, token);
+        oss << llama_token_to_str(ctx, token);
     }
     beam_search_response = oss.str();
 
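In short, the change stops duplicating shared state in every beam: the llama_context no longer lives inside each beam value, and the functions that need it (out_beam, fill_next_beams_by_top_probabilities, the body of llama_beam_search) take it as an explicit parameter instead. The sketch below is a minimal, self-contained illustration of that pattern under assumed names; context, candidate, and print_candidate are hypothetical stand-ins, not llama.cpp's actual types or API.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for llama_context: shared, non-owning state.
struct context {
    std::string model_name;
};

// Value type analogous to `beam` after this commit: it no longer stores a
// context pointer, so it stays a plain aggregate of tokens and probability.
struct candidate {
    std::vector<int> tokens;
    float p;  // cumulative probability
};

// Analogous to out_beam(): the context is an explicit parameter rather than
// a member read off the value, so the caller decides which context is used.
void print_candidate(std::ostream& os, const context& ctx, const candidate& c) {
    os << ctx.model_name << " p(" << c.p << ") tokens(";
    for (int t : c.tokens) {
        os << ' ' << t;
    }
    os << " )\n";
}

int main() {
    context ctx{"demo-model"};
    std::vector<candidate> candidates;
    candidates.reserve(2);
    // One fewer element in the brace-init list than before the refactor,
    // because the context pointer member is gone.
    candidates.push_back({{}, 1.0f});
    candidates.push_back({{1, 2, 3}, 0.5f});
    for (const candidate& c : candidates) {
        print_candidate(std::cout, ctx, c);
    }
    return 0;
}

Keeping the value type free of the context pointer is what allows the shorter aggregate initializer beams.push_back({{}, 1.0}); seen in the diff, and it avoids storing beam_width identical pointers as beams are copied and swapped each iteration.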