
Commit 3229586

perplexity : fix integer overflow
ggml-ci
1 parent 6374743 commit 3229586
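
Every hunk below makes the same change: n_vocab is widened from int to int64_t so that index arithmetic over the logits buffer is carried out in 64 bits. As a rough illustration of the bug class (a minimal sketch with hypothetical values, not the exact overflow site in perplexity.cpp): with a 32-bit n_vocab, a product such as i * n_vocab is evaluated in 32 bits, and for a large vocabulary it can exceed INT_MAX before it is ever widened for pointer arithmetic.

    // overflow_sketch.cpp -- illustrative only; the values are assumptions, not taken from the commit
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t n_vocab = 256000; // hypothetical large vocabulary
        const int32_t i       = 8400;   // hypothetical token offset into a batch of logits

        // 256000 * 8400 = 2'150'400'000 > INT_MAX: signed 32-bit overflow (undefined behavior,
        // which in practice typically surfaces as a negative offset and an out-of-bounds read)
        const int64_t bad = n_vocab * i;            // product computed in 32 bits, then widened
        const int64_t ok  = (int64_t) n_vocab * i;  // product computed in 64 bits

        printf("32-bit product: %lld, 64-bit product: %lld\n", (long long) bad, (long long) ok);
        return 0;
    }

Once n_vocab itself is int64_t, as in the hunks below, expressions like i * n_vocab promote to 64 bits without a cast at each call site.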

File tree

1 file changed (+23, -15)

examples/perplexity/perplexity.cpp

Lines changed: 23 additions & 15 deletions
@@ -103,7 +103,7 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
-static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
+static results_log_softmax log_softmax(int64_t n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
     for (int i = 1; i < n_vocab; ++i) {
         max_logit = std::max(max_logit, logits[i]);
@@ -122,7 +122,7 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }
 
-static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
+static double log_softmax(int64_t n_vocab, const float * logits, uint16_t * log_prob, int tok) {
     float max_logit = logits[0];
     float min_logit = logits[0];
     for (int i = 1; i < n_vocab; ++i) {
@@ -153,7 +153,7 @@ static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob
 }
 
 static void process_logits(
-    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+    int64_t n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
     double & nll, double & nll2, float * logit_history, float * prob_history
 ) {
     std::mutex mutex;
@@ -187,7 +187,7 @@ static void process_logits(
     }
 }
 
-static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
+static void process_logits(std::ostream& out, int64_t n_vocab, const float * logits, const int * tokens, int n_token,
         std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
     std::mutex mutex;
     const int nv = 2*((n_vocab + 1)/2) + 4;
@@ -234,7 +234,7 @@ struct kl_divergence_result {
     size_t count = 0.0;
 };
 
-static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+static std::pair<double, float> log_softmax(int64_t n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
     float max_logit = logits[0];
     int imax = 0;
     for (int i = 1; i < n_vocab; ++i) {
@@ -281,7 +281,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     kld.sum_kld += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
-    if (imax == imax_base) ++kld.n_same_top;
+    if (imax == imax_base) {
+        ++kld.n_same_top;
+    }
 
     const float p_base = expf(-nll_base);
     const float p = expf(-nll);
@@ -295,7 +297,7 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     return std::make_pair(sum, p_diff);
 }
 
-static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
+static void process_logits(int64_t n_vocab, const float * logits, const int * tokens, int n_token,
         std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
         float * kld_values, float * p_diff_values) {
     std::mutex mutex;
@@ -383,9 +385,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;
 
@@ -521,9 +524,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const int n_chunk_max = tokens.size() / n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;
     double nll2 = 0.0;
@@ -723,7 +727,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
 
 #define K_TOKEN_CHUNK 4
 
-static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
+static void compute_logprobs(const float * batch_logits, int64_t n_vocab, std::vector<std::thread>& workers,
         const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
@@ -877,10 +881,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
     double acc = 0.0f;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
@@ -1158,10 +1163,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
 
     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
@@ -1509,10 +1515,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
 
     LOG("\ntask\tacc_norm\n");
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
@@ -1709,15 +1716,16 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
     }
 
-    int n_vocab, n_chunk;
+    int64_t n_vocab;
+    int64_t n_chunk;
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
         LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
         return;
     }
     if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
-        LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
+        LOG_ERR("%s: inconsistent vocabulary (%lld vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
     }
 
     std::vector<llama_token> tokens(n_ctx * n_chunk);
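
One consequence of the last hunk worth noting (an observation from the diff itself, not a documented migration note): because the reads use sizeof(n_vocab), widening n_vocab and n_chunk to int64_t means the loader now consumes 8 bytes per header field from the logits file instead of 4, so the writer side of that file is presumably widened to match. A minimal sketch of the implied header read, assuming a stream positioned at the start of the file:

    // header_read_sketch.cpp -- illustrative reader for the implied header layout (assumption)
    #include <cstdint>
    #include <fstream>

    static bool read_logits_header(std::ifstream & in, int64_t & n_vocab, int64_t & n_chunk) {
        in.read((char *) &n_vocab, sizeof(n_vocab)); // 8 bytes after this commit (was 4 as int)
        in.read((char *) &n_chunk, sizeof(n_chunk)); // 8 bytes after this commit (was 4 as int)
        return !in.fail();
    }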
