Skip to content

Commit 3359308

Browse files · committed
Rename beam_width to n_beams for consistency with existing convention.
1 parent f75fab4 commit 3359308

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

examples/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct gpt_params {
2929
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
3030
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
3131
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
32-
int32_t n_beams = 0; // Used in mem allocation if > 0 and by llama_beam_search().
32+
int32_t n_beams = 0; // if non-zero then use beam search of given width.
3333
float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon
3434
float rope_freq_base = 10000.0f; // RoPE base frequency
3535
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

llama.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,7 +2954,7 @@ struct logit_info {
29542954

29552955
struct beam_search {
29562956
llama_context * ctx;
2957-
size_t beam_width;
2957+
size_t n_beams;
29582958
int n_past;
29592959
int n_predict;
29602960
int n_threads;
@@ -2969,15 +2969,15 @@ struct beam_search {
29692969
// Used to communicate to/from callback on beams state.
29702970
std::vector<llama_beam_view> beam_views;
29712971

2972-
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
2972+
beam_search(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
29732973
: ctx(ctx)
2974-
, beam_width(beam_width)
2974+
, n_beams(n_beams)
29752975
, n_past(n_past)
29762976
, n_predict(n_predict)
29772977
, n_threads(n_threads)
2978-
, beam_views(beam_width) {
2979-
beams.reserve(beam_width);
2980-
next_beams.reserve(beam_width);
2978+
, beam_views(n_beams) {
2979+
beams.reserve(n_beams);
2980+
next_beams.reserve(n_beams);
29812981
}
29822982

29832983
// Collapse beams to a single beam given by index.
@@ -2988,7 +2988,7 @@ struct beam_search {
29882988
beams.resize(1);
29892989
}
29902990

2991-
// Min-heaps are used to efficiently collect the top-k elements (k=beam_width).
2991+
// Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
29922992
// The repetative patterns below reflect the 2 stages of heaps:
29932993
// * Gather elements until the vector is full, then call std::make_heap() on it.
29942994
// * If the heap is full and a new element is found that should be included, pop the
@@ -3003,9 +3003,9 @@ struct beam_search {
30033003
}
30043004
if (beam.eos) {
30053005
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3006-
if (next_beams.size() < beam_width) {
3006+
if (next_beams.size() < n_beams) {
30073007
next_beams.push_back(std::move(beam));
3008-
if (next_beams.size() == beam_width) {
3008+
if (next_beams.size() == n_beams) {
30093009
std::make_heap(next_beams.begin(), next_beams.end(), comp);
30103010
}
30113011
} else if (next_beams.front().p < beam.p) {
@@ -3024,10 +3024,10 @@ struct beam_search {
30243024
}
30253025
}
30263026
logit_info logit_info(ctx);
3027-
std::vector<llama_token_data> next_tokens = logit_info.top_k(beam_width);
3027+
std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
30283028
size_t i=0;
3029-
if (next_beams.size() < beam_width) {
3030-
for (; next_beams.size() < beam_width ; ++i) {
3029+
if (next_beams.size() < n_beams) {
3030+
for (; next_beams.size() < n_beams ; ++i) {
30313031
llama_beam next_beam = beam;
30323032
next_beam.tokens.push_back(next_tokens[i].id);
30333033
next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
@@ -3043,7 +3043,7 @@ struct beam_search {
30433043
std::push_heap(next_beams.begin(), next_beams.end(), comp);
30443044
}
30453045
}
3046-
for (; i < beam_width ; ++i) {
3046+
for (; i < n_beams ; ++i) {
30473047
float const next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
30483048
if (next_beams.front().p < next_p) {
30493049
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
@@ -3076,7 +3076,9 @@ struct beam_search {
30763076
// Side effect: set common_prefix_length = find_common_prefix_length();
30773077
llama_beams_state get_beams_state(bool const last_call) {
30783078
for (size_t i=0 ; i<beams.size() ; ++i) {
3079-
beam_views[i] = beams[i].view();
3079+
//beam_views[i] = beams[i].view();
3080+
auto view = beams.at(i).view();
3081+
beam_views.at(i) = view; // capacity 0
30803082
}
30813083
common_prefix_length = find_common_prefix_length();
30823084
return {beam_views.data(), beams.size(), common_prefix_length, last_call};
@@ -3132,11 +3134,11 @@ struct beam_search {
31323134

31333135
void llama_beam_search(llama_context * ctx,
31343136
llama_beam_search_callback_fn_t callback, void* callback_state,
3135-
size_t beam_width, int n_past, int n_predict, int n_threads) {
3137+
size_t n_beams, int n_past, int n_predict, int n_threads) {
31363138
assert(ctx);
31373139
const int64_t t_start_sample_us = ggml_time_us();
31383140

3139-
beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);
3141+
beam_search beam_search(ctx, n_beams, n_past, n_predict, n_threads);
31403142

31413143
beam_search.loop(callback, callback_state);
31423144

llama.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -487,11 +487,11 @@ extern "C" {
487487
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
488488
/// The return beam_search_control can be used to control the beam_search execution.
489489
/// @param callback_state A pointer that is simply passed back to callback.
490-
/// @param beam_width The number of parallel beams to use.
491-
/// @param n_past The number of tokens already evaluated.
492-
/// @param n_predict The maximum number of tokens to predict.
493-
/// @param n_threads The maximum number of threads as passed to llama_eval().
494-
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, size_t beam_width, int n_past, int n_predict, int n_threads);
490+
/// @param n_beams Number of beams to use.
491+
/// @param n_past Number of tokens already evaluated.
492+
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
493+
/// @param n_threads Number of threads as passed to llama_eval().
494+
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, size_t n_beams, int n_past, int n_predict, int n_threads);
495495

496496
// Performance information
497497
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);

0 commit comments

Comments (0)