5
5
#include < cmath>
6
6
#include < ctime>
7
7
#include < sstream>
8
+ #include < cstring>
8
9
9
10
#if defined(_MSC_VER)
10
11
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -121,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
121
122
printf (" \n " );
122
123
}
123
124
125
+ std::vector<float > hellaswag_evaluate_tokens (llama_context * ctx, const std::vector<int >& tokens, int n_past, int n_batch,
126
+ int n_vocab, int n_thread) {
127
+ std::vector<float > result;
128
+ result.reserve (tokens.size () * n_vocab);
129
+ size_t n_chunk = (tokens.size () + n_batch - 1 )/n_batch;
130
+ for (size_t i_chunk = 0 ; i_chunk < n_chunk; ++i_chunk) {
131
+ size_t n_tokens = tokens.size () - i_chunk * n_batch;
132
+ n_tokens = std::min (n_tokens, size_t (n_batch));
133
+ if (llama_eval (ctx, tokens.data () + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
134
+ fprintf (stderr, " %s : failed to eval\n " , __func__);
135
+ return {};
136
+ }
137
+
138
+ const auto logits = llama_get_logits (ctx);
139
+ result.insert (result.end (), logits, logits + n_tokens * n_vocab);
140
+
141
+ n_past += n_tokens;
142
+ }
143
+ return result;
144
+ }
145
+
124
146
void hellaswag_score (llama_context * ctx, const gpt_params & params) {
125
147
// Calculates hellaswag score (acc_norm) from prompt
126
148
//
@@ -209,50 +231,93 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
209
231
double acc = 0 .0f ;
210
232
const int n_vocab = llama_n_vocab (ctx);
211
233
234
+ std::vector<float > tok_logits (n_vocab);
235
+
212
236
for (size_t task_idx = 0 ; task_idx < hs_task_count; task_idx++) {
213
237
214
238
// Tokenize the context to count tokens
215
239
std::vector<int > context_embd = ::llama_tokenize (ctx, hs_data[task_idx].context , prepend_bos);
216
240
size_t context_size = context_embd.size ();
217
241
218
- for (size_t ending_idx=0 ;ending_idx<4 ;ending_idx++) {
242
+ // Do the 1st ending
243
+ // In this case we include the context when evaluating
244
+ auto query_embd = ::llama_tokenize (ctx, hs_data[task_idx].context + hs_data[task_idx].ending [0 ], prepend_bos);
245
+ auto query_size = query_embd.size ();
246
+ // printf("First query: %d\n",(int)query_size);
247
+
248
+ // Stop if query wont fit the ctx window
249
+ if (query_size > (size_t )params.n_ctx ) {
250
+ fprintf (stderr, " %s : number of tokens in query %zu > n_ctxl\n " , __func__, query_size);
251
+ return ;
252
+ }
253
+
254
+ // Speedup small evaluations by evaluating atleast 32 tokens
255
+ if (query_size < 32 ) {
256
+ query_embd.resize (32 );
257
+ }
258
+
259
+ auto logits = hellaswag_evaluate_tokens (ctx, query_embd, 0 , params.n_batch , n_vocab, params.n_threads );
260
+ if (logits.empty ()) {
261
+ fprintf (stderr, " %s : failed to eval\n " , __func__);
262
+ return ;
263
+ }
264
+
265
+ std::memcpy (tok_logits.data (), logits.data () + (context_size-1 )*n_vocab, n_vocab*sizeof (float ));
266
+ const auto first_probs = softmax (tok_logits);
267
+
268
+ hs_data[task_idx].ending_logprob_count [0 ] = 1 ;
269
+ hs_data[task_idx].ending_logprob [0 ] = std::log (first_probs[query_embd[context_size]]);
270
+
271
+ // Calculate the logprobs over the ending
272
+ for (size_t j = context_size; j < query_size - 1 ; j++) {
273
+
274
+ std::memcpy (tok_logits.data (), logits.data () + j*n_vocab, n_vocab*sizeof (float ));
275
+
276
+ const float prob = softmax (tok_logits)[query_embd[j + 1 ]];
277
+
278
+ hs_data[task_idx].ending_logprob [0 ] += std::log (prob);
279
+ hs_data[task_idx].ending_logprob_count [0 ]++;
280
+ }
281
+
282
+ // Calculate the mean token logprob for acc_norm
283
+ hs_data[task_idx].ending_logprob [0 ] /= hs_data[task_idx].ending_logprob_count [0 ];
284
+
285
+ // Do the remaining endings
286
+ // For these, we use the bare ending with n_past = context_size
287
+ //
288
+ for (size_t ending_idx = 1 ; ending_idx < 4 ; ending_idx++) {
219
289
220
290
// Tokenize the query
221
- std::vector< int > query_embd = ::llama_tokenize (ctx, hs_data[task_idx].context + hs_data[task_idx]. ending [ending_idx], prepend_bos );
222
- size_t query_size = query_embd.size ();
291
+ query_embd = ::llama_tokenize (ctx, hs_data[task_idx].ending [ending_idx], false );
292
+ query_size = query_embd.size ();
223
293
224
294
// Stop if query wont fit the ctx window
225
- if (query_size > (size_t )params.n_ctx ) {
295
+ if (context_size + query_size > (size_t )params.n_ctx ) {
226
296
fprintf (stderr, " %s : number of tokens in query %zu > n_ctxl\n " , __func__, query_size);
227
297
return ;
228
298
}
229
299
230
300
// Speedup small evaluations by evaluating atleast 32 tokens
231
- if (query_size < 32 ) {
232
- query_embd.resize (32 );
233
- }
301
+ // No, resizing to 32 is actually slightly slower (at least on CUDA)
302
+ // if (query_size < 32) {
303
+ // query_embd.resize(32);
304
+ // }
234
305
235
306
// Evaluate the query
236
- if (llama_eval (ctx, query_embd.data (), query_embd.size (), 0 , params.n_threads )) {
307
+ logits = hellaswag_evaluate_tokens (ctx, query_embd, context_size, params.n_batch , n_vocab, params.n_threads );
308
+ if (logits.empty ()) {
237
309
fprintf (stderr, " %s : failed to eval\n " , __func__);
238
310
return ;
239
311
}
240
312
241
- const auto query_logits = llama_get_logits (ctx);
242
- std::vector<float > logits;
243
- logits.insert (logits.end (), query_logits, query_logits + query_size * n_vocab);
244
-
245
- hs_data[task_idx].ending_logprob_count [ending_idx] = 0 ;
246
- hs_data[task_idx].ending_logprob [ending_idx] = 0 .0f ;
313
+ hs_data[task_idx].ending_logprob_count [ending_idx] = 1 ;
314
+ hs_data[task_idx].ending_logprob [ending_idx] = std::log (first_probs[query_embd[0 ]]);
247
315
248
316
// Calculate the logprobs over the ending
249
- for (size_t j = context_size-1 ; j < query_size - 1 ; j++) {
250
- // Calculate probability of next token, given the previous ones.
251
- const std::vector<float > tok_logits (
252
- logits.begin () + (j + 0 ) * n_vocab,
253
- logits.begin () + (j + 1 ) * n_vocab);
317
+ for (size_t j = 0 ; j < query_size - 1 ; j++) {
318
+ std::memcpy (tok_logits.data (), logits.data () + j*n_vocab, n_vocab*sizeof (float ));
254
319
255
- const float prob = softmax (tok_logits)[query_embd[ j + 1 ]];
320
+ const float prob = softmax (tok_logits)[query_embd[j + 1 ]];
256
321
257
322
hs_data[task_idx].ending_logprob [ending_idx] += std::log (prob);
258
323
hs_data[task_idx].ending_logprob_count [ending_idx]++;
@@ -267,9 +332,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
267
332
}
268
333
269
334
// Find the ending with maximum logprob
270
- size_t ending_logprob_max_idx = - 1 ;
271
- double ending_logprob_max_val = -INFINITY ;
272
- for (size_t j= 0 ; j < 4 ; j++) {
335
+ size_t ending_logprob_max_idx = 0 ;
336
+ double ending_logprob_max_val = hs_data[task_idx]. ending_logprob [ 0 ] ;
337
+ for (size_t j = 1 ; j < 4 ; j++) {
273
338
if (hs_data[task_idx].ending_logprob [j] > ending_logprob_max_val) {
274
339
ending_logprob_max_idx = j;
275
340
ending_logprob_max_val = hs_data[task_idx].ending_logprob [j];
0 commit comments