Skip to content

Commit 10c8fd2

Browse files
committed
Cleanup: Change beam_width from type int to size_t.
1 parent f20e584 commit 10c8fd2

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ int main(int argc, char ** argv)
140140
n_past += tokens_list.size();
141141

142142
std::vector<llama_token> response;
143+
size_t const beam_width = static_cast<size_t>(params.n_beams);
143144
int const n_predict = 256;
144-
llama_beam_search(ctx, beam_search_callback, &response, params.n_beams, n_past, n_predict, params.n_threads);
145+
llama_beam_search(ctx, beam_search_callback, &response, beam_width, n_past, n_predict, params.n_threads);
145146

146147
printf("\n\n");
147148
for (llama_token const token_id : response) {

llama.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2934,21 +2934,21 @@ struct logit_info {
29342934
, max_l(*std::max_element(logits, logits + n_vocab))
29352935
, normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
29362936
{ }
2937-
llama_token_data get_token_data(int const token_id) const {
2937+
llama_token_data get_token_data(llama_token const token_id) const {
29382938
constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
29392939
return {token_id, logits[token_id], p};
29402940
}
29412941
// Return top k token_data by logit.
2942-
std::vector<llama_token_data> top_k(int k) {
2942+
std::vector<llama_token_data> top_k(size_t k) {
29432943
std::vector<llama_token_data> min_heap; // min-heap by logit
2944-
k = std::min(k, n_vocab);
2945-
min_heap.reserve(k);
2946-
for (int token_id=0 ; token_id<k ; ++token_id) {
2944+
llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
2945+
min_heap.reserve(k_min);
2946+
for (llama_token token_id=0 ; token_id<k_min ; ++token_id) {
29472947
min_heap.push_back(get_token_data(token_id));
29482948
}
29492949
auto comp = [](llama_token_data const& a, llama_token_data const& b) { return a.logit > b.logit; };
29502950
std::make_heap(min_heap.begin(), min_heap.end(), comp);
2951-
for (int token_id=k ; token_id<n_vocab ; ++token_id) {
2951+
for (llama_token token_id=k_min ; token_id<n_vocab ; ++token_id) {
29522952
if (min_heap.front().logit < logits[token_id]) {
29532953
std::pop_heap(min_heap.begin(), min_heap.end(), comp);
29542954
min_heap.back().id = token_id;
@@ -2965,7 +2965,7 @@ struct logit_info {
29652965

29662966
struct beam_search {
29672967
llama_context * ctx;
2968-
int beam_width;
2968+
size_t beam_width;
29692969
int n_past;
29702970
int n_predict;
29712971
int n_threads;
@@ -2981,7 +2981,7 @@ struct beam_search {
29812981
std::vector<size_t> beam_lengths;
29822982
std::vector<llama_token const*> beam_ptrs;
29832983

2984-
beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2984+
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
29852985
: ctx(ctx)
29862986
, beam_width(beam_width)
29872987
, n_past(n_past)
@@ -3016,9 +3016,9 @@ struct beam_search {
30163016
}
30173017
if (b.eos()) {
30183018
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3019-
if (next_beams.size() < static_cast<size_t>(beam_width)) {
3019+
if (next_beams.size() < beam_width) {
30203020
next_beams.push_back(std::move(b));
3021-
if (next_beams.size() == static_cast<size_t>(beam_width)) {
3021+
if (next_beams.size() == beam_width) {
30223022
std::make_heap(next_beams.begin(), next_beams.end(), comp);
30233023
}
30243024
} else if (next_beams.front().p < b.p) {
@@ -3038,9 +3038,9 @@ struct beam_search {
30383038
}
30393039
logit_info logit_info(ctx);
30403040
std::vector<llama_token_data> next_tokens = logit_info.top_k(beam_width);
3041-
int i=0;
3042-
if (next_beams.size() < static_cast<size_t>(beam_width)) {
3043-
for (; next_beams.size() < static_cast<size_t>(beam_width) ; ++i) {
3041+
size_t i=0;
3042+
if (next_beams.size() < beam_width) {
3043+
for (; next_beams.size() < beam_width ; ++i) {
30443044
beam next_beam = b;
30453045
next_beam.tokens.push_back(next_tokens[i].id);
30463046
next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
@@ -3146,7 +3146,7 @@ struct beam_search {
31463146

31473147
void llama_beam_search(llama_context * ctx,
31483148
llama_beam_search_callback_fn_t callback, void* callback_state,
3149-
int beam_width, int n_past, int const n_predict, int const n_threads) {
3149+
size_t beam_width, int n_past, int n_predict, int n_threads) {
31503150
assert(ctx);
31513151
const int64_t t_start_sample_us = ggml_time_us();
31523152

llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ extern "C" {
490490
/// @param n_past The number of tokens already evaluated.
491491
/// @param n_predict The maximum number of tokens to predict.
492492
/// @param n_threads The maximum number of threads as passed to llama_eval().
493-
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, int beam_width, int n_past, int n_predict, int n_threads);
493+
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, size_t beam_width, int n_past, int n_predict, int n_threads);
494494

495495
// Performance information
496496
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);

0 commit comments

Comments (0)