Skip to content

Commit f9cd683

Browse files
authored
sampling : make sure samplers return at least 1 token (#13822)
* sampling : min-p should always return at least one token ggml-ci * sampling : same for typical sampling * tests : sampling tests use min_keep == 0 ggml-ci
1 parent 4f81b33 commit f9cd683

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/llama-sampling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
798798
}
799799

800800
// if we have enough values the operation was a success
801-
if (filtered_tokens.size() >= ctx->min_keep) {
801+
if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
802802
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
803803
cur_p->size = filtered_tokens.size();
804804
min_p_applied = true;
@@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
909909
cum_sum += cur_p->data[idx].p;
910910

911911
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
912-
if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
912+
if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
913913
last_idx = i + 1;
914914
break;
915915
}

tests/test-sampling.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
9898
sampler_tester tester(probs, probs_expected);
9999

100100
DUMP(&tester.cur_p);
101-
tester.apply(llama_sampler_init_top_p(p, 1));
101+
tester.apply(llama_sampler_init_top_p(p, 0));
102102
tester.apply(llama_sampler_init_dist (0));
103103
DUMP(&tester.cur_p);
104104

@@ -109,7 +109,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
109109
sampler_tester tester(probs, probs_expected);
110110

111111
DUMP(&tester.cur_p);
112-
tester.apply(llama_sampler_init_min_p(p, 1));
112+
tester.apply(llama_sampler_init_min_p(p, 0));
113113
tester.apply(llama_sampler_init_dist (0));
114114
DUMP(&tester.cur_p);
115115

@@ -130,7 +130,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
130130
sampler_tester tester(probs, probs_expected);
131131

132132
DUMP(&tester.cur_p);
133-
tester.apply(llama_sampler_init_typical(p, 1));
133+
tester.apply(llama_sampler_init_typical(p, 0));
134134
DUMP(&tester.cur_p);
135135

136136
tester.check();
@@ -332,6 +332,7 @@ int main(void) {
332332
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
333333
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
334334
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
335+
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.05f);
335336

336337
printf("XTC should:\n");
337338
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
@@ -341,8 +342,8 @@ int main(void) {
341342
printf("XTC should not:\n");
342343
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
343344

344-
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
345-
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
345+
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
346+
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
346347

347348
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
348349
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);

0 commit comments

Comments
 (0)