Skip to content

Commit 2a62535

Browse files
committed
tests : try to fix tail free sampling test
ggml-ci
1 parent d4d7d2f commit 2a62535

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

llama.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2024,9 +2024,18 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
20242024
}
20252025

20262026
// Normalize the second derivatives
2027-
float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
2028-
for (float & value : second_derivatives) {
2029-
value /= second_derivatives_sum;
2027+
{
2028+
const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
2029+
2030+
if (second_derivatives_sum > 1e-6f) {
2031+
for (float & value : second_derivatives) {
2032+
value /= second_derivatives_sum;
2033+
}
2034+
} else {
2035+
for (float & value : second_derivatives) {
2036+
value = 1.0f / second_derivatives.size();
2037+
}
2038+
}
20302039
}
20312040

20322041
float cum_sum = 0.0f;

tests/test-sampling.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,6 @@ int main(void) {
200200
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
201201

202202
printf("OK\n");
203+
204+
return 0;
203205
}

0 commit comments

Comments
 (0)