Skip to content

Commit 26d4efd

Browse files
sampling: fix top_k <= 0 (#5388)
* sampling: fix top_k <= 0
* Update llama.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 8504d2d commit 26d4efd

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

common/sampling.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ static void sampler_queue(
132 132   const float temp = params.temp;
133 133   const float dynatemp_range = params.dynatemp_range;
134 134   const float dynatemp_exponent = params.dynatemp_exponent;
135     - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
    135 + const int32_t top_k = params.top_k;
136 136   const float top_p = params.top_p;
137 137   const float min_p = params.min_p;
138 138   const float tfs_z = params.tfs_z;

llama.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8585,6 +8585,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
8585 8585   // }
8586 8586
8587 8587   const int64_t t_start_sample_us = ggml_time_us();
     8588 +
     8589 + if (k <= 0) {
     8590 +     k = candidates->size;
     8591 + }
8588 8592
8589 8593   k = std::max(k, (int) min_keep);
8590 8594   k = std::min(k, (int) candidates->size);

tests/test-sampling.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,8 @@ int main(void) {
235 235
236 236   test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
237 237   test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
    238 + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
    239 + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
238 240
239 241   test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
240 242   test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);

0 commit comments

Comments
 (0)