
Commit e702206

perplexity : fix integer overflow (#9783)
* perplexity : fix integer overflow ggml-ci
* perplexity : keep n_vocab as int and make appropriate casts ggml-ci
1 parent 3dc48fe commit e702206
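
Note on the pattern being fixed (this explanation is not part of the commit itself): every change below widens one operand of an index product to size_t before it is used for pointer arithmetic or a buffer size, because products such as i*n_vocab are otherwise evaluated in int and can exceed INT_MAX for large contexts. A minimal, self-contained sketch of why the cast matters, using illustrative sizes rather than values taken from this code:

    #include <climits>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Illustrative sizes (hypothetical, not from the diff): a large vocabulary and
        // many rows of logits stored in one flat buffer of n_rows * n_vocab floats.
        const int n_vocab = 128256;
        const int n_rows  = 20000;

        // Offset needed to reach the last row:
        const int64_t needed = int64_t(n_rows - 1) * n_vocab;   // 2,564,991,744

        // Before the fix, offsets were formed as int*int (e.g. logits + i*n_vocab).
        // Once the product exceeds INT_MAX this is signed overflow: undefined behavior,
        // which in practice wraps to a negative offset and indexes out of bounds.
        printf("needed offset: %lld (INT_MAX = %d) -> overflows int: %s\n",
               (long long) needed, INT_MAX, needed > INT_MAX ? "yes" : "no");

        // The fix: perform the multiplication in size_t, as in logits + size_t(i)*n_vocab,
        // so the full offset is representable.
        const size_t safe = size_t(n_rows - 1) * size_t(n_vocab);
        printf("size_t offset: %zu\n", safe);

        return 0;
    }

The same reasoning applies to the std::vector sizes and memcpy byte counts touched below, where quantities like n_ctx*n_vocab and n_outputs*n_vocab*sizeof(float) can also exceed the range of int.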

File tree

1 file changed: +49 -36 lines changed

examples/perplexity/perplexity.cpp

Lines changed: 49 additions & 36 deletions
@@ -169,7 +169,7 @@ static void process_logits(
                 break;
             }
             lock.unlock();
-            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
             const double v = -results.log_softmax;
             local_nll += v;
             local_nll2 += v*v;
@@ -203,7 +203,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
                 break;
             }
             lock.unlock();
-            const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+            const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
             local_nll += v;
             local_nll2 += v*v;
         }
@@ -281,7 +281,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     kld.sum_kld += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
-    if (imax == imax_base) ++kld.n_same_top;
+    if (imax == imax_base) {
+        ++kld.n_same_top;
+    }

     const float p_base = expf(-nll_base);
     const float p = expf(-nll);
@@ -323,7 +325,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
                 break;
             }
             lock.unlock();
-            std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
             kld_values[i] = (float)v.first;
             p_diff_values[i] = v.second;
         }
@@ -383,9 +385,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;

@@ -424,8 +427,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-            const auto batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            const auto * batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);

             if (j == 0) {
                 tokens[batch_start] = token_org;
@@ -447,11 +450,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &

         //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
         for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
-
             // Calculate probability of next token, given the previous ones.
             const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
+                logits.begin() + size_t(j + 0) * n_vocab,
+                logits.begin() + size_t(j + 1) * n_vocab);

             const float prob = softmax(tok_logits)[tokens[start + j + 1]];
             logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
@@ -521,9 +523,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const int n_chunk_max = tokens.size() / n_ctx;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;
     double nll2 = 0.0;
@@ -538,7 +541,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     std::vector<float> logits;
     if (num_batches > 1) {
-        logits.reserve((size_t)n_ctx * n_vocab);
+        logits.reserve(size_t(n_ctx) * n_vocab);
     }

     LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
@@ -620,7 +623,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

             if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
             }
         }

@@ -661,7 +664,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         } else {
             double av = nll/count;
             double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
+            if (av2 > 0) {
+                av2 = sqrt(av2/(count-1));
+            }
             LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
     }
@@ -686,10 +691,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }

-static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
     int prev_outputs = 0;
-    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
+        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);

         llama_batch batch_view = {
             n_tokens,
@@ -713,7 +718,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             n_outputs += batch_view.logits[i] != 0;
         }

-        memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

         prev_outputs += n_outputs;
     }
@@ -728,19 +733,23 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
     }
-    if (eval_pairs.empty()) return;
+    if (eval_pairs.empty()) {
+        return;
+    }

     size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());

     std::atomic<int> counter(0);
     auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
         float local_logprobs[K_TOKEN_CHUNK];
         while (true) {
-            size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
-            if (first >= eval_results.size()) break;
-            size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
+            const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
+            if (first >= eval_results.size()) {
+                break;
+            }
+            const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
             for (size_t i = first; i < last; ++i) {
-                auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+                const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
                 float max_logit = logits[0];
                 for (int j = 1; j < n_vocab; ++j) {
                     max_logit = std::max(max_logit, logits[j]);
@@ -877,18 +886,19 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

     double acc = 0.0f;

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

     llama_batch batch = llama_batch_init(n_ctx, 0, 4);

     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
-    std::vector<float> batch_logits(n_vocab*n_ctx);
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -975,7 +985,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             auto & hs_cur = hs_data[i];

             // get the logits of the last token of the common prefix
-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -1158,18 +1168,19 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

     llama_batch batch = llama_batch_init(n_ctx, 0, 2);

     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
-    std::vector<float> batch_logits(n_vocab*n_ctx);
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -1509,17 +1520,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params

     LOG("\ntask\tacc_norm\n");

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(n_vocab*n_ctx);
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -1627,7 +1639,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             //LOG("\n common_prefix: %zu\n", cur_task.common_prefix);

             // get the logits of the last token of the common prefix
-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -1709,7 +1721,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
                 __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
     }

-    int n_vocab, n_chunk;
+    int n_vocab;
+    int n_chunk;
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
@@ -1720,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
         LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
     }

-    std::vector<llama_token> tokens(n_ctx * n_chunk);
+    std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
     if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
         LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
         return;
@@ -1737,7 +1750,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
-        logits.reserve(n_ctx * n_vocab);
+        logits.reserve(size_t(n_ctx) * n_vocab);
     }

     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -1801,7 +1814,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {

             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
             }
         }

@@ -1822,7 +1835,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {

         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                 workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
         p_diff_ptr += n_ctx - 1 - first;
         kld_ptr += n_ctx - 1 - first;

0 commit comments
