@@ -222,13 +222,18 @@ struct kl_divergence_result {
     double sum_kld2      = 0;
     double sum_nll_diff  = 0;
     double sum_nll_diff2 = 0;
+    size_t n_same_top    = 0;
     size_t count         = 0;
 };
 
-static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
     float max_logit = logits[0];
+    int imax = 0;
     for (int i = 1; i < n_vocab; ++i) {
-        max_logit = std::max(max_logit, logits[i]);
+        if (logits[i] > max_logit) {
+            max_logit = logits[i];
+            imax = i;
+        }
     }
     double sum_exp = 0.0;
     for (int i = 0; i < n_vocab; ++i) {
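
Note: the explicit comparison replaces std::max so that the index of the largest logit (imax, the evaluated model's top token) is found in the same pass; max_logit also serves as the shift in the usual numerically stable softmax. A minimal standalone sketch of the same computation (log_prob_of is a hypothetical helper, not part of this patch):

    #include <cmath>

    // Numerically stable log-softmax that also reports the argmax,
    // mirroring the loop above.
    static double log_prob_of(const float * logits, int n, int tok, int & argmax) {
        float max_logit = logits[0];
        argmax = 0;
        for (int i = 1; i < n; ++i) {
            if (logits[i] > max_logit) { max_logit = logits[i]; argmax = i; }
        }
        double sum_exp = 0.0;
        for (int i = 0; i < n; ++i) {
            sum_exp += exp(logits[i] - max_logit);  // shift prevents overflow
        }
        // log p(tok) = logit(tok) - max_logit - log(sum of shifted exps)
        return logits[tok] - max_logit - log(sum_exp);
    }
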
@@ -247,8 +252,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
     kld.sum_nll_diff2 += nll*nll;
     max_logit += log_sum_exp;
     double sum = 0;
+    int imax_base = -1;
+    float p_log_base_max = 0;
     for (int i = 0; i < n_vocab; ++i) {
         const float p_log_base = scale*base_log_prob[i] + min_log_prob;
+        if (i == 0 || p_log_base > p_log_base_max) {
+            p_log_base_max = p_log_base;
+            imax_base = i;
+        }
         if (p_log_base > -16.f) {
             const float p_base = expf(p_log_base);
             sum += p_base * (p_log_base - logits[i] + max_logit);
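
For context: base_log_prob holds the base model's log-probabilities linearly quantized to 16 bits, so scale*base_log_prob[i] + min_log_prob dequantizes entry i (the extra 4 uint16 slots in the nv stride presumably carry the two dequantization floats per row), and the p_log_base > -16.f guard skips tokens whose base probability (below e^-16 ≈ 1.1e-7) contributes nothing measurable to the KL sum. A sketch of the round trip this implies, assuming a plain linear code over [min_log_prob, min_log_prob + 65535*scale]:

    #include <cmath>
    #include <cstdint>

    // Assumed linear 16-bit quantization of log-probabilities.
    static uint16_t encode_log_prob(float log_p, float min_log_prob, float scale) {
        float q = (log_p - min_log_prob)/scale;
        q = q < 0.f ? 0.f : q > 65535.f ? 65535.f : q;  // clamp to code range
        return (uint16_t)(q + 0.5f);                    // round to nearest
    }

    static float decode_log_prob(uint16_t code, float min_log_prob, float scale) {
        return scale*code + min_log_prob;  // matches p_log_base above
    }
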
@@ -257,14 +268,17 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
     kld.sum_kld  += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
+    if (imax == imax_base) ++kld.n_same_top;
+    return sum;
 }
 
 static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
-        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
+        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
+        float * kld_values) {
     std::mutex mutex;
     const int nv = 2*((n_vocab + 1)/2) + 4;
     int counter = 0;
-    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv]() {
+    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values]() {
         kl_divergence_result local_kld;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
@@ -276,11 +290,13 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
                 kld.sum_kld2      += local_kld.sum_kld2;
                 kld.sum_nll_diff  += local_kld.sum_nll_diff;
                 kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
+                kld.n_same_top    += local_kld.n_same_top;
                 kld.count         += local_kld.count;
                 break;
             }
             lock.unlock();
-            log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            kld_values[i] = (float)v;
         }
     };
     for (auto & w : workers) {
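
The compute lambda is a self-scheduling worker: each thread draws the next row index from the shared counter under the mutex, runs log_softmax unlocked, and folds its local_kld into the shared kld exactly once before exiting. The same pattern in isolation (a simplified sketch, with a plain sum standing in for the KLD accumulators):

    #include <mutex>
    #include <thread>
    #include <vector>

    static double parallel_sum(const std::vector<double> & data, int n_thread) {
        std::mutex mutex;
        int counter = 0;
        double total = 0;
        auto compute = [&]() {
            double local = 0;
            while (true) {
                std::unique_lock<std::mutex> lock(mutex);
                int i = counter++;
                if (i >= (int)data.size()) {
                    total += local;  // merge once, while still holding the lock
                    break;
                }
                lock.unlock();
                local += data[i];    // the expensive work happens unlocked
            }
        };
        std::vector<std::thread> workers;
        for (int t = 0; t < n_thread; ++t) workers.emplace_back(compute);
        for (auto & w : workers) w.join();
        return total;
    }
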
@@ -1615,7 +1631,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
-        fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
+        fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
         return;
     }
     if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
@@ -1634,6 +1650,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
+    std::vector<float>    kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve(n_ctx * n_vocab);
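
As a sanity check on the allocation size: only the second half of each window is scored and the final position has no next-token target, so each chunk contributes n_ctx - 1 - n_ctx/2 values. A quick worked example, assuming n_ctx = 512 and n_chunk = 100 (illustrative values only):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const int n_ctx = 512, n_chunk = 100;          // illustrative values
        const size_t per_chunk = n_ctx - 1 - n_ctx/2;  // 255 scored positions
        printf("%zu KLD values total\n", per_chunk*n_chunk);  // prints 25500
    }
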
@@ -1652,6 +1669,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     };
 
     kl_divergence_result kld;
+    auto kld_ptr = kld_values.data();
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * n_ctx;
@@ -1705,27 +1723,60 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
         }
         fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
 
-            printf("\nchunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence\n");
+            printf("\nchunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence           Same top\n");
         }
 
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
         process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                workers, log_probs_uint16, kld);
+                workers, log_probs_uint16, kld, kld_ptr);
+        kld_ptr += n_ctx - 1 - first;
 
         auto ppl           = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
         auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
         auto kl_div        = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+        auto p_top = 1.*kld.n_same_top/kld.count;
+        auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1));
 
-        printf("%4d    %10.4lf    %10.5lf ± %10.5f    %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
-                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
+        printf("%4d    %10.4lf    %10.5lf ± %10.5f    %10.5f ± %10.5lf    %.5f ± %.5f\n", i+1, exp(ppl.first),
+                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second,
+                p_top, d_p_top);
 
         fflush(stdout);
 
         logits.clear();
     }
     printf("\n");
 
+    if (kld.count < 100) return; // we do not wish to do statistics on so few values
+
+    std::sort(kld_values.begin(), kld_values.end());
+
+    printf("===== KL-divergence statistics\n");
+    auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+    printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second);
+    auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
+                                               : kld_values[kld_values.size()/2];
+    printf("Median : %10.6f\n", kld_median);
+
+    auto percentile = [&kld_values] (float fraction) {
+        if (fraction <= 0) return kld_values.front();
+        if (fraction >= 1) return kld_values.back();
+        float p = fraction*(kld_values.size() - 1);
+        size_t ip = size_t(p); p -= ip;
+        return (1 - p)*kld_values[ip] + p*kld_values[std::min(ip+1, kld_values.size()-1)];
+    };
+
+    printf("Maximum: %10.6f\n", kld_values.back());
+    printf("KLD_99 : %10.6f\n", percentile(0.99f));
+    printf("KLD_95 : %10.6f\n", percentile(0.95f));
+    printf("KLD_90 : %10.6f\n", percentile(0.90f));
+
+    printf("Minimum: %10.6f\n", kld_values.front());
+    printf("KLD_01 : %10.6f\n", percentile(0.01f));
+    printf("KLD_05 : %10.6f\n", percentile(0.05f));
+    printf("KLD_10 : %10.6f\n", percentile(0.10f));
+
 }
 
 int main(int argc, char ** argv) {
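
Two statistical details in the hunk above are easy to misread: the ± sqrt(p*(1-p)/(count-1)) term is the standard error of the "same top token" proportion treated as a Bernoulli sample, and percentile does linear interpolation between the two order statistics bracketing fraction*(size-1) (the same convention as NumPy's default). A small self-contained check of both, assuming sorted input (all names and values here are illustrative):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Linear-interpolation percentile over sorted values, mirroring the
    // lambda in the patch: index p = fraction*(n-1), blend the neighbours.
    static float percentile(const std::vector<float> & v, float fraction) {
        if (fraction <= 0) return v.front();
        if (fraction >= 1) return v.back();
        float p = fraction*(v.size() - 1);
        size_t ip = (size_t)p; p -= ip;
        return (1 - p)*v[ip] + p*v[std::min(ip + 1, v.size() - 1)];
    }

    int main() {
        std::vector<float> v = {0.f, 1.f, 2.f, 3.f, 4.f};
        printf("median = %g\n", percentile(v, 0.5f));  // 2
        printf("p90    = %g\n", percentile(v, 0.9f));  // 3.6 = 0.4*3 + 0.6*4

        // Standard error of a proportion: n_same_top successes out of count.
        size_t n_same_top = 800, count = 1000;
        double p_top   = 1.*n_same_top/count;
        double d_p_top = sqrt(p_top*(1 - p_top)/(count - 1));
        printf("p_top = %.3f +/- %.3f\n", p_top, d_p_top);  // 0.800 +/- 0.013
    }
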