@@ -1,5 +1,4 @@
 #include "utils.h"
-#include "ggml.h"
 #include "llama.h"

 #include <cassert>
@@ -65,79 +64,6 @@ void set_console_state(console_state new_st)
     }
 }

-std::vector<double> softmax(const std::vector<float>& logits) {
-    std::vector<double> probs(logits.size());
-    float max_logit = logits[0];
-    for (float v : logits) max_logit = std::max(max_logit, v);
-    double sum_exp = 0.0;
-    for (size_t i = 0; i < logits.size(); i++) {
-        // Subtract the maximum logit value from the current logit value for numerical stability
-        float logit = logits[i] - max_logit;
-        double exp_logit = std::exp(logit);
-        sum_exp += exp_logit;
-        probs[i] = exp_logit;
-    }
-    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
-    return probs;
-}
-
-void perplexity(llama_context * ctx, const gpt_params & params) {
-    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
-    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
-    // Output: `perplexity: 13.5106 [114/114]`
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
-
-    int count = 0;
-    double nll = 0.0;
-    int seq_count = tokens.size() / params.n_ctx;
-
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
-
-    for (int i = 0; i < seq_count; ++i) {
-        int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1;
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
-        auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
-        }
-        auto end_t = std::chrono::high_resolution_clock::now();
-        if (i == 0) {
-            double seconds = std::chrono::duration<double>(end_t - start_t).count();
-            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
-        }
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
-            // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
-            std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
-            double prob = softmax(tok_logits)[tokens[start + j + 1]];
-            nll += -std::log(prob);
-            ++count;
-        }
-        // perplexity is e^(average negative log-likelihood)
-        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        fflush(stdout);
-    }
-    printf("\n");
-}
-
 static bool is_interacting = false;

 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
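
Note: the perplexity() routine removed in the hunk above reports e^(average negative log-likelihood), where each token's probability comes from a softmax over its logits and only the second half of each context window is scored. Below is a minimal standalone sketch of the same calculation, using hypothetical inputs rather than the llama.cpp API: logits holds one row of n_vocab scores per position, targets[j] is the token that actually follows position j, and the caller decides which positions to score.

    // Sketch only: perplexity = exp(mean negative log-likelihood) over the scored positions.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    static std::vector<double> softmax(const std::vector<float> & logits) {
        // subtract the max logit before exponentiating, for numerical stability
        float max_logit = logits[0];
        for (float v : logits) max_logit = std::max(max_logit, v);
        std::vector<double> probs(logits.size());
        double sum_exp = 0.0;
        for (size_t i = 0; i < logits.size(); i++) {
            probs[i] = std::exp(double(logits[i] - max_logit));
            sum_exp += probs[i];
        }
        for (double & p : probs) p /= sum_exp;
        return probs;
    }

    // logits: n_positions * n_vocab values, row-major; targets: the ground-truth next token per position
    static double perplexity_of(const std::vector<float> & logits, const std::vector<int> & targets, int n_vocab) {
        double nll = 0.0;
        for (size_t j = 0; j < targets.size(); j++) {
            std::vector<float> row(logits.begin() + j * n_vocab, logits.begin() + (j + 1) * n_vocab);
            nll += -std::log(softmax(row)[targets[j]]);
        }
        return std::exp(nll / targets.size()); // e^(average NLL)
    }

The names perplexity_of and targets are illustrative only; the removed code pulls logits from llama_get_logits() and targets from the tokenized prompt.
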
@@ -154,21 +80,7 @@ void sigint_handler(int signo) {
 }
 #endif

-int run(int argc, char ** argv) {
-    // has to be called once at the start of the program to init ggml stuff
-    ggml_time_init();
-
-    gpt_params params;
-    params.model = "models/llama-7B/ggml-model.bin";
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
-    }
-
-    if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
-    }
+int run(llama_context * ctx, gpt_params params) {

     if (params.seed <= 0) {
         params.seed = time(NULL);
@@ -188,45 +100,13 @@ int run(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     // bool is_prime(int n) {)";

-    llama_context * ctx;
-
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx = params.n_ctx;
-        lparams.n_parts = params.n_parts;
-        lparams.seed = params.seed;
-        lparams.f16_kv = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
     // determine the required inference memory per token:
     // TODO: better way to do that
     {
         const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
         llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
     }

-    if (params.perplexity) {
-        perplexity(ctx, params);
-        exit(0);
-    }
-
     int n_past = 0;

     // Add a space in front of the first character to match OG llama tokenizer behavior
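
Note on the run() refactor: with argument parsing, ggml initialization, and model loading removed from this file, run() now expects a fully initialized llama_context and parsed gpt_params from its caller. A hypothetical call site, pieced together from the removed lines (the actual wrapper introduced by this change is not shown in these hunks), could look like:

    // Hypothetical caller; every API call below appears in the code removed above.
    #include "ggml.h"
    #include "llama.h"
    #include "utils.h"

    #include <cstdio>

    int run(llama_context * ctx, gpt_params params); // the refactored entry point

    int main(int argc, char ** argv) {
        ggml_time_init(); // has to be called once at the start of the program

        gpt_params params;
        params.model = "models/llama-7B/ggml-model.bin";
        if (gpt_params_parse(argc, argv, params) == false) {
            return 1;
        }

        auto lparams = llama_context_default_params();
        lparams.n_ctx   = params.n_ctx;
        lparams.n_parts = params.n_parts;
        lparams.seed    = params.seed;
        lparams.f16_kv  = params.memory_f16;

        llama_context * ctx = llama_init_from_file(params.model.c_str(), lparams);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load model '%s'\n", params.model.c_str());
            return 1;
        }

        return run(ctx, params); // run() now receives a ready context and parsed params
    }

This only sketches the division of responsibility implied by the new signature; option handling such as logits_all is left to the real caller.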