Commit b7f1fa6

Move llama_context setup + perplexity back to main.cpp
Signed-off-by: Thiago Padilha <[email protected]>
1 parent d7d53b8 commit b7f1fa6

3 files changed (+128, -123 lines)

main.cpp

Lines changed: 123 additions & 1 deletion
@@ -1,5 +1,127 @@
 #include "run.h"
+#include "ggml.h"
+
+
+std::vector<double> softmax(const std::vector<float>& logits) {
+    std::vector<double> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) max_logit = std::max(max_logit, v);
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        float logit = logits[i] - max_logit;
+        double exp_logit = std::exp(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    return probs;
+}
+
+void perplexity(llama_context * ctx, const gpt_params & params) {
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    int count = 0;
+    double nll = 0.0;
+    int seq_count = tokens.size() / params.n_ctx;
+
+    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+
+    for (int i = 0; i < seq_count; ++i) {
+        int start = i * params.n_ctx;
+        int end = start + params.n_ctx - 1;
+        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
+        auto start_t = std::chrono::high_resolution_clock::now();
+        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+        auto end_t = std::chrono::high_resolution_clock::now();
+        if (i == 0) {
+            double seconds = std::chrono::duration<double>(end_t - start_t).count();
+            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
+        }
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example, we have a context window of 512, we will compute perplexity for each of the
+        // last 256 tokens. Then, we split the input up into context window size chunks to
+        // process the entire prompt.
+
+        auto logits = llama_get_logits(ctx);
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = llama_n_vocab(ctx);
+            std::vector<float> tok_logits(
+                logits + j * n_vocab,
+                logits + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        fflush(stdout);
+    }
+    printf("\n");
+}
 
 int main(int argc, char ** argv) {
-    return run(argc, argv);
+    // has to be called once at the start of the program to init ggml stuff
+    ggml_time_init();
+
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.n_ctx > 2048) {
+        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
+                "expect poor results\n", __func__, params.n_ctx);
+    }
+
+    llama_context * ctx;
+
+    // load the model
+    {
+        auto lparams = llama_context_default_params();
+
+        lparams.n_ctx      = params.n_ctx;
+        lparams.n_parts    = params.n_parts;
+        lparams.seed       = params.seed;
+        lparams.f16_kv     = params.memory_f16;
+        lparams.logits_all = params.perplexity;
+
+        ctx = llama_init_from_file(params.model.c_str(), lparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+
+    // print system information
+    {
+        fprintf(stderr, "\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+    }
+
+    if (params.perplexity) {
+        perplexity(ctx, params);
+        exit(0);
+    }
+
+    return run(ctx, params);
 }
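
As a side note on what the relocated perplexity() routine computes: each chunk's score is the exponential of the average negative log-likelihood of the predicted tokens, i.e. perplexity = exp((1/N) * sum_i -log p(token_i | preceding tokens)), with the softmax stabilized by subtracting the maximum logit before exponentiating. The sketch below is a minimal, self-contained illustration of just that arithmetic; it is not part of this commit, and the hard-coded logits and target tokens are invented for the example.

// Minimal standalone sketch (not part of this commit): perplexity as
// exp(mean negative log-likelihood), using a max-subtracted softmax
// for numerical stability. Logits and targets below are made up.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<double> softmax(const std::vector<float>& logits) {
    std::vector<double> probs(logits.size());
    float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(double(logits[i] - max_logit)); // subtract max before exp
        sum_exp += probs[i];
    }
    for (double & p : probs) p /= sum_exp;
    return probs;
}

int main() {
    // Three prediction points over a 4-token vocabulary, plus the
    // token that actually came next at each point (all invented).
    std::vector<std::vector<float>> logits = {
        {2.0f, 0.5f, 0.1f, -1.0f},
        {0.2f, 1.7f, 0.3f,  0.0f},
        {0.1f, 0.1f, 3.0f,  0.2f},
    };
    std::vector<int> next_token = {0, 1, 2};

    double nll = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        nll += -std::log(softmax(logits[i])[next_token[i]]);
    }
    // perplexity is e^(average negative log-likelihood)
    printf("perplexity: %.4f\n", std::exp(nll / logits.size()));
    return 0;
}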

run.cpp

Lines changed: 1 addition & 121 deletions
@@ -1,5 +1,4 @@
 #include "utils.h"
-#include "ggml.h"
 #include "llama.h"
 
 #include <cassert>
@@ -65,79 +64,6 @@ void set_console_state(console_state new_st)
     }
 }
 
-std::vector<double> softmax(const std::vector<float>& logits) {
-    std::vector<double> probs(logits.size());
-    float max_logit = logits[0];
-    for (float v : logits) max_logit = std::max(max_logit, v);
-    double sum_exp = 0.0;
-    for (size_t i = 0; i < logits.size(); i++) {
-        // Subtract the maximum logit value from the current logit value for numerical stability
-        float logit = logits[i] - max_logit;
-        double exp_logit = std::exp(logit);
-        sum_exp += exp_logit;
-        probs[i] = exp_logit;
-    }
-    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
-    return probs;
-}
-
-void perplexity(llama_context * ctx, const gpt_params & params) {
-    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
-    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
-    // Output: `perplexity: 13.5106 [114/114]`
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
-
-    int count = 0;
-    double nll = 0.0;
-    int seq_count = tokens.size() / params.n_ctx;
-
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
-
-    for (int i = 0; i < seq_count; ++i) {
-        int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1;
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
-        auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
-        }
-        auto end_t = std::chrono::high_resolution_clock::now();
-        if (i == 0) {
-            double seconds = std::chrono::duration<double>(end_t - start_t).count();
-            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
-        }
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
-            // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
-            std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
-            double prob = softmax(tok_logits)[tokens[start + j + 1]];
-            nll += -std::log(prob);
-            ++count;
-        }
-        // perplexity is e^(average negative log-likelihood)
-        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        fflush(stdout);
-    }
-    printf("\n");
-}
-
 static bool is_interacting = false;
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -154,21 +80,7 @@ void sigint_handler(int signo) {
 }
 #endif
 
-int run(int argc, char ** argv) {
-    // has to be called once at the start of the program to init ggml stuff
-    ggml_time_init();
-
-    gpt_params params;
-    params.model = "models/llama-7B/ggml-model.bin";
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
-    }
-
-    if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
-    }
+int run(llama_context * ctx, gpt_params params) {
 
     if (params.seed <= 0) {
         params.seed = time(NULL);
@@ -188,45 +100,13 @@ int run(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
 //bool is_prime(int n) {)";
 
-    llama_context * ctx;
-
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
     // determine the required inference memory per token:
     // TODO: better way to do that
     {
         const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
         llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
     }
 
-    if (params.perplexity) {
-        perplexity(ctx, params);
-        exit(0);
-    }
-
     int n_past = 0;
 
     // Add a space in front of the first character to match OG llama tokenizer behavior

run.h

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,6 @@
 #pragma once
 
-int run(int argc, char ** argv);
+#include "llama.h"
+#include "utils.h"
+
+int run(llama_context * ctx, gpt_params params);
