Skip to content

Commit 4e702e6

Browse files
committed
Cleanup: Change beam_width from type int to size_t.
1 parent e156b30 commit 4e702e6

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ int main(int argc, char ** argv)
140140
n_past += tokens_list.size();
141141

142142
std::vector<llama_token> response;
143+
size_t const beam_width = static_cast<size_t>(params.n_beams);
143144
int const n_predict = 256;
144-
llama_beam_search(ctx, beam_search_callback, &response, params.n_beams, n_past, n_predict, params.n_threads);
145+
llama_beam_search(ctx, beam_search_callback, &response, beam_width, n_past, n_predict, params.n_threads);
145146

146147
printf("\n\n");
147148
for (llama_token const token_id : response) {

llama.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2919,21 +2919,21 @@ struct logit_info {
29192919
, max_l(*std::max_element(logits, logits + n_vocab))
29202920
, normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
29212921
{ }
2922-
llama_token_data get_token_data(int const token_id) const {
2922+
llama_token_data get_token_data(llama_token const token_id) const {
29232923
constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
29242924
return {token_id, logits[token_id], p};
29252925
}
29262926
// Return top k token_data by logit.
2927-
std::vector<llama_token_data> top_k(int k) {
2927+
std::vector<llama_token_data> top_k(size_t k) {
29282928
std::vector<llama_token_data> min_heap; // min-heap by logit
2929-
k = std::min(k, n_vocab);
2930-
min_heap.reserve(k);
2931-
for (int token_id=0 ; token_id<k ; ++token_id) {
2929+
llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
2930+
min_heap.reserve(k_min);
2931+
for (llama_token token_id=0 ; token_id<k_min ; ++token_id) {
29322932
min_heap.push_back(get_token_data(token_id));
29332933
}
29342934
auto comp = [](llama_token_data const& a, llama_token_data const& b) { return a.logit > b.logit; };
29352935
std::make_heap(min_heap.begin(), min_heap.end(), comp);
2936-
for (int token_id=k ; token_id<n_vocab ; ++token_id) {
2936+
for (llama_token token_id=k_min ; token_id<n_vocab ; ++token_id) {
29372937
if (min_heap.front().logit < logits[token_id]) {
29382938
std::pop_heap(min_heap.begin(), min_heap.end(), comp);
29392939
min_heap.back().id = token_id;
@@ -2950,7 +2950,7 @@ struct logit_info {
29502950

29512951
struct beam_search {
29522952
llama_context * ctx;
2953-
int beam_width;
2953+
size_t beam_width;
29542954
int n_past;
29552955
int n_predict;
29562956
int n_threads;
@@ -2966,7 +2966,7 @@ struct beam_search {
29662966
std::vector<size_t> beam_lengths;
29672967
std::vector<llama_token const*> beam_ptrs;
29682968

2969-
beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2969+
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
29702970
: ctx(ctx)
29712971
, beam_width(beam_width)
29722972
, n_past(n_past)
@@ -3001,9 +3001,9 @@ struct beam_search {
30013001
}
30023002
if (b.eos()) {
30033003
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3004-
if (next_beams.size() < static_cast<size_t>(beam_width)) {
3004+
if (next_beams.size() < beam_width) {
30053005
next_beams.push_back(std::move(b));
3006-
if (next_beams.size() == static_cast<size_t>(beam_width)) {
3006+
if (next_beams.size() == beam_width) {
30073007
std::make_heap(next_beams.begin(), next_beams.end(), comp);
30083008
}
30093009
} else if (next_beams.front().p < b.p) {
@@ -3023,9 +3023,9 @@ struct beam_search {
30233023
}
30243024
logit_info logit_info(ctx);
30253025
std::vector<llama_token_data> next_tokens = logit_info.top_k(beam_width);
3026-
int i=0;
3027-
if (next_beams.size() < static_cast<size_t>(beam_width)) {
3028-
for (; next_beams.size() < static_cast<size_t>(beam_width) ; ++i) {
3026+
size_t i=0;
3027+
if (next_beams.size() < beam_width) {
3028+
for (; next_beams.size() < beam_width ; ++i) {
30293029
beam next_beam = b;
30303030
next_beam.tokens.push_back(next_tokens[i].id);
30313031
next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
@@ -3131,7 +3131,7 @@ struct beam_search {
31313131

31323132
void llama_beam_search(llama_context * ctx,
31333133
llama_beam_search_callback_fn_t callback, void* callback_state,
3134-
int beam_width, int n_past, int const n_predict, int const n_threads) {
3134+
size_t beam_width, int n_past, int n_predict, int n_threads) {
31353135
assert(ctx);
31363136
const int64_t t_start_sample_us = ggml_time_us();
31373137

llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ extern "C" {
473473
/// @param n_past The number of tokens already evaluated.
474474
/// @param n_predict The maximum number of tokens to predict.
475475
/// @param n_threads The maximum number of threads as passed to llama_eval().
476-
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, int beam_width, int n_past, int n_predict, int n_threads);
476+
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_state, size_t beam_width, int n_past, int n_predict, int n_threads);
477477

478478
// Performance information
479479
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);

0 commit comments

Comments (0)