Skip to content

Commit 7ede556

Browse files
sampling: fix top_k <= 0
1 parent 213d143 commit 7ede556

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

llama.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8371,6 +8371,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
83718371
// return;
83728372
// }
83738373

8374+
if (k <= 0) {
8375+
k = candidates->size;
8376+
}
8377+
83748378
const int64_t t_start_sample_us = ggml_time_us();
83758379

83768380
k = std::max(k, (int) min_keep);

tests/test-sampling.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ int main(void) {
235235

236236
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
237237
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
238+
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
239+
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
238240

239241
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
240242
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);

0 commit comments

Comments
 (0)