@@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
1757
1757
// find max logit and calculate mean
1758
1758
float max = cur_p->data [0 ].logit ;
1759
1759
float logits_sum = 0 ;
1760
+ size_t valid_count = 0 ;
1760
1761
for (size_t i = 0 ; i < cur_p->size ; ++i) {
1761
- if (cur_p->data [i].logit > max) {
1762
- max = cur_p->data [i].logit ;
1762
+ // Only count non-negative infinity values
1763
+ if (cur_p->data [i].logit != -INFINITY) {
1764
+ if (cur_p->data [i].logit > max) {
1765
+ max = cur_p->data [i].logit ;
1766
+ }
1767
+ logits_sum += cur_p->data [i].logit ;
1768
+ valid_count++;
1763
1769
}
1764
- logits_sum += cur_p->data [i].logit ;
1765
1770
}
1766
- float mean = logits_sum/cur_p-> size ;
1771
+ float mean = valid_count > 0 ? logits_sum/valid_count : 0 ;
1767
1772
1768
1773
// calculate standard deviation
1769
1774
float acc = 0 ;
1770
1775
for (size_t i = 0 ; i < cur_p->size ; ++i) {
1771
- acc += pow (cur_p->data [i].logit - mean, 2 );
1776
+ // Skip -infinity in std calculation
1777
+ if (cur_p->data [i].logit != -INFINITY) {
1778
+ acc += pow (cur_p->data [i].logit - mean, 2 );
1779
+ }
1772
1780
}
1773
- float std = sqrt (acc/cur_p-> size ) ;
1781
+ float std = valid_count > 0 ? sqrt (acc/valid_count) : 0 ;
1774
1782
1775
1783
// apply mask
1776
1784
for (size_t i = 0 ; i < cur_p->size ; ++i) {
0 commit comments