Commit d138fcf

ikawrakow (Kawrakow) authored and committed
Additional KL-divergence statistics (ggml-org#5081)
* perplexity: add top-token probability
* perplexity: add additional KL-divergence statistics
* perplexity: a better organized KL-divergence statistics output

Co-authored-by: Iwan Kawrakow <[email protected]>
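For context on the first bullet: the new "Same top" column reports how often the quantized model's most probable token matches the base model's, with a binomial standard error, as in the commit's `d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1))`. A minimal standalone sketch of that statistic, assuming the per-position argmax token ids of the two models are already available (all names and the toy data here are illustrative, not from the commit):

#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper: fraction of positions where the base and quantized
// models agree on the top token, plus the binomial uncertainty used in the
// commit: d_p = sqrt(p*(1-p)/(n-1)).
static void same_top_stats(const std::vector<int> & top_base, const std::vector<int> & top_quant) {
    size_t n_same = 0;
    const size_t n = top_base.size();
    for (size_t i = 0; i < n; ++i) {
        if (top_base[i] == top_quant[i]) ++n_same;
    }
    const double p  = double(n_same) / n;
    const double dp = std::sqrt(p * (1 - p) / (n - 1));
    printf("Same top p: %.5f ± %.5f\n", p, dp);
}

int main() {
    // Toy data: argmax token ids of two models over 8 positions (6 agree).
    std::vector<int> top_base  {3, 7, 7, 1, 9, 2, 5, 5};
    std::vector<int> top_quant {3, 7, 4, 1, 9, 2, 5, 0};
    same_top_stats(top_base, top_quant); // p = 6/8 = 0.75
    return 0;
}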
1 parent ebf24a1 commit d138fcf

File tree

1 file changed: +61 -10 lines changed

examples/perplexity/perplexity.cpp

Lines changed: 61 additions & 10 deletions
@@ -222,13 +222,18 @@ struct kl_divergence_result {
     double sum_kld2 = 0;
     double sum_nll_diff = 0;
     double sum_nll_diff2 = 0;
+    size_t n_same_top = 0;
     size_t count = 0;
 };
 
-static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
     float max_logit = logits[0];
+    int imax = 0;
     for (int i = 1; i < n_vocab; ++i) {
-        max_logit = std::max(max_logit, logits[i]);
+        if (logits[i] > max_logit) {
+            max_logit = logits[i];
+            imax = i;
+        }
     }
     double sum_exp = 0.0;
     for (int i = 0; i < n_vocab; ++i) {
@@ -247,8 +252,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
     kld.sum_nll_diff2 += nll*nll;
     max_logit += log_sum_exp;
     double sum = 0;
+    int imax_base = -1;
+    float p_log_base_max = 0;
     for (int i = 0; i < n_vocab; ++i) {
         const float p_log_base = scale*base_log_prob[i] + min_log_prob;
+        if (i == 0 || p_log_base > p_log_base_max) {
+            p_log_base_max = p_log_base;
+            imax_base = i;
+        }
         if (p_log_base > -16.f) {
             const float p_base = expf(p_log_base);
             sum += p_base * (p_log_base - logits[i] + max_logit);
@@ -257,14 +268,17 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
     kld.sum_kld += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
+    if (imax == imax_base) ++kld.n_same_top;
+    return sum;
 }
 
 static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
-        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
+        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
+        float * kld_values) {
     std::mutex mutex;
     const int nv = 2*((n_vocab + 1)/2) + 4;
     int counter = 0;
-    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
+    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values] () {
         kl_divergence_result local_kld;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
@@ -276,11 +290,13 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
                 kld.sum_kld2 += local_kld.sum_kld2;
                 kld.sum_nll_diff += local_kld.sum_nll_diff;
                 kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
+                kld.n_same_top += local_kld.n_same_top;
                 kld.count += local_kld.count;
                 break;
             }
             lock.unlock();
-            log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            kld_values[i] = (float)v;
         }
     };
     for (auto & w : workers) {
@@ -1615,7 +1631,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
-        fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
+        fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
         return;
     }
     if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
@@ -1634,6 +1650,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
+    std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve(n_ctx * n_vocab);
@@ -1652,6 +1669,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     };
 
     kl_divergence_result kld;
+    auto kld_ptr = kld_values.data();
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * n_ctx;
@@ -1705,27 +1723,60 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
 
-            printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence\n");
+            printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top\n");
         }
 
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
         process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                workers, log_probs_uint16, kld);
+                workers, log_probs_uint16, kld, kld_ptr);
+        kld_ptr += n_ctx - 1 - first;
 
         auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
         auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
         auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+        auto p_top = 1.*kld.n_same_top/kld.count;
+        auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1));
 
-        printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
-                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
+        printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf %.5f ± %.5f\n", i+1, exp(ppl.first),
+                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second,
+                p_top, d_p_top);
 
         fflush(stdout);
 
         logits.clear();
     }
     printf("\n");
 
+    if (kld.count < 100) return; // we do not wish to do statistics on so few values
+
+    std::sort(kld_values.begin(), kld_values.end());
+
+    printf("===== KL-divergence statistics\n");
+    auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+    printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second);
+    auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
+                                               : kld_values[kld_values.size()/2];
+    printf("Median : %10.6f\n", kld_median);
+
+    auto percentile = [&kld_values] (float fraction) {
+        if (fraction <= 0) return kld_values.front();
+        if (fraction >= 1) return kld_values.back();
+        float p = fraction*(kld_values.size() - 1);
+        size_t ip = size_t(p); p -= ip;
+        return (1 - p)*kld_values[ip] + p*kld_values[std::min(ip+1, kld_values.size()-1)];
+    };
+
+    printf("Maximum: %10.6f\n", kld_values.back());
+    printf("KLD_99 : %10.6f\n", percentile(0.99f));
+    printf("KLD_95 : %10.6f\n", percentile(0.95f));
+    printf("KLD_90 : %10.6f\n", percentile(0.90f));
+
+    printf("Minimum: %10.6f\n", kld_values.front());
+    printf("KLD_01 : %10.6f\n", percentile(0.01f));
+    printf("KLD_05 : %10.6f\n", percentile(0.05f));
+    printf("KLD_10 : %10.6f\n", percentile(0.10f));
+
 }
 
 int main(int argc, char ** argv) {
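A note on the new statistics block above: the KLD_xx rows are percentiles with linear interpolation between the two nearest order statistics of the sorted per-token KL divergences, the same scheme as the commit's `percentile` lambda. A minimal self-contained sketch of the median and percentile computation (the toy values are illustrative; the real code fills `kld_values` in `process_logits`):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // Toy per-token KL divergences between a base and a quantized model.
    std::vector<float> kld_values {0.01f, 0.50f, 0.03f, 0.20f, 0.02f, 0.08f, 0.04f, 0.15f};
    std::sort(kld_values.begin(), kld_values.end());

    // Median: average the two middle values for an even-sized sample.
    float median = kld_values.size() % 2 == 0
        ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2 - 1])
        : kld_values[kld_values.size()/2];

    // Percentile with linear interpolation between adjacent order statistics:
    // fraction f maps to position f*(n-1) in the sorted sample, and the
    // fractional part blends the two neighboring values.
    auto percentile = [&kld_values] (float fraction) {
        if (fraction <= 0) return kld_values.front();
        if (fraction >= 1) return kld_values.back();
        float p = fraction*(kld_values.size() - 1);
        size_t ip = size_t(p); p -= ip;
        return (1 - p)*kld_values[ip] + p*kld_values[std::min(ip + 1, kld_values.size() - 1)];
    };

    printf("Median : %10.6f\n", median);
    printf("KLD_90 : %10.6f\n", percentile(0.90f)); // 90% of tokens fall below this KLD
    printf("KLD_10 : %10.6f\n", percentile(0.10f));
    return 0;
}

Reporting the high percentiles (KLD_99/95/90) alongside the average is what makes the output useful for quantization comparisons: a quantization can look fine on average while a small tail of tokens diverges badly, and that tail shows up directly in these rows.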
