Skip to content

Commit 91a86a6

Browse files
authored
sampling : don't consider -infinity values in top_n_sigma (#13344)
1 parent f4ed10b commit 91a86a6

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/llama-sampling.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
17571757
// find max logit and calculate mean
17581758
float max = cur_p->data[0].logit;
17591759
float logits_sum = 0;
1760+
size_t valid_count = 0;
17601761
for (size_t i = 0; i < cur_p->size; ++i) {
1761-
if (cur_p->data[i].logit > max) {
1762-
max = cur_p->data[i].logit;
1762+
// Only count logits that are not negative infinity
1763+
if (cur_p->data[i].logit != -INFINITY) {
1764+
if (cur_p->data[i].logit > max) {
1765+
max = cur_p->data[i].logit;
1766+
}
1767+
logits_sum += cur_p->data[i].logit;
1768+
valid_count++;
17631769
}
1764-
logits_sum += cur_p->data[i].logit;
17651770
}
1766-
float mean = logits_sum/cur_p->size;
1771+
float mean = valid_count > 0 ? logits_sum/valid_count : 0;
17671772

17681773
// calculate standard deviation
17691774
float acc = 0;
17701775
for (size_t i = 0; i < cur_p->size; ++i) {
1771-
acc += pow(cur_p->data[i].logit - mean, 2);
1776+
// Skip -infinity in std calculation
1777+
if (cur_p->data[i].logit != -INFINITY) {
1778+
acc += pow(cur_p->data[i].logit - mean, 2);
1779+
}
17721780
}
1773-
float std = sqrt(acc/cur_p->size);
1781+
float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
17741782

17751783
//apply mask
17761784
for (size_t i = 0; i < cur_p->size; ++i) {

0 commit comments

Comments
 (0)