@@ -470,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         prompt_lines.push_back(line);
     }

-    if( prompt_lines.size() % 6 != 0) {
+    if (prompt_lines.size() % 6 != 0) {
         fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
         return;
     }
@@ -485,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

     // Number of tasks to use when computing the score
-    if( params.hellaswag_tasks < hs_task_count ) {
+    if (params.hellaswag_tasks < hs_task_count) {
        hs_task_count = params.hellaswag_tasks;
     }

@@ -502,178 +502,215 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         std::string ending[4];
         size_t ending_logprob_count[4];
         double ending_logprob[4];
+
+        size_t i_batch;         // starting index in the llama_batch
+        size_t common_prefix;   // max number of initial tokens that are the same in all sentences
+        size_t required_tokens; // needed number of tokens to evaluate all 4 endings
+        std::vector<llama_token> seq_tokens[4];
     };

     fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );

     // Select and read data from prompt lines
-    hs_data_t *hs_data = new hs_data_t[hs_task_count];
-    for (size_t i=0; i < hs_task_count; i++) {
+    std::vector<hs_data_t> hs_data(hs_task_count);
+    for (size_t i = 0; i < hs_task_count; i++) {
         size_t idx = i;

+        auto & hs_cur = hs_data[i];
+
         // Select a random example of those left in the prompt
         if (randomize_tasks) {
             std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1);
             idx = dist(rng);
         }

-        hs_data[i].context = prompt_lines[idx*6];
-        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
-        for (size_t j=0; j < 4; j++) {
-            hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
+        hs_cur.context = prompt_lines[idx*6];
+        hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
+        for (size_t j = 0; j < 4; j++) {
+            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
+            hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
         }

+        // determine the common prefix of the endings
+        hs_cur.common_prefix = 0;
+        hs_cur.required_tokens = 0;
+        for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
+            if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
+                break;
+            }
+            hs_cur.common_prefix++;
+        }
+        hs_cur.required_tokens = hs_cur.common_prefix +
+            hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
+
+        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
+
         // Delete the selected random example from the prompt
         if (randomize_tasks) {
             prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
         }
     }

fprintf (stderr, " %s : calculating hellaswag score over selected tasks.\n " , __func__);
560
+
533
561
printf (" \n task\t acc_norm\n " );
534
562
535
563
double acc = 0 .0f ;
564
+
536
565
const int n_vocab = llama_n_vocab (llama_get_model (ctx));
537
- const int n_ctx = llama_n_ctx (ctx);
566
+ const int n_ctx = llama_n_ctx (ctx);
567
+ const int n_batch = params.n_batch ;
538
568
539
- std::vector<std::vector<int >> ending_tokens (4 );
569
+ const int max_tasks_per_batch = params.n_parallel ;
570
+ const int max_seq = 4 *max_tasks_per_batch;
540
571
541
- std::vector< float > tok_logits (n_vocab );
572
+ llama_batch batch = llama_batch_init (n_ctx, 0 , max_seq );
542
573
543
- for (size_t task_idx = 0 ; task_idx < hs_task_count; task_idx++) {
544
- // Tokenize the context to count tokens
545
- std::vector<int > context_embd = ::llama_tokenize (ctx, hs_data[task_idx].context , add_bos);
546
- size_t context_size = context_embd.size ();
547
-
548
- for (int i = 0 ; i < 4 ; ++i) {
549
- ending_tokens[i] = ::llama_tokenize (ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending [i], add_bos);
550
- for (int k = 0 ; k < int (context_size); ++k) {
551
- if (ending_tokens[i][k] != context_embd[k]) {
552
- fprintf (stderr, " Oops: ending %d of task %d differs from context at position %d\n " ,i,int (task_idx),k);
553
- break ;
554
- }
574
+ std::vector<float > tok_logits (n_vocab);
575
+ std::vector<float > batch_logits (n_ctx*n_vocab);
576
+
577
+ auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
578
+ for (int32_t i = 0 ; i < (int32_t ) batch.n_tokens ; i += n_batch) {
579
+ const int32_t n_tokens = std::min (n_batch, (int32_t ) (batch.n_tokens - i));
580
+
581
+ llama_batch batch_view = {
582
+ n_tokens,
583
+ batch.token + i,
584
+ nullptr ,
585
+ batch.pos + i,
586
+ batch.n_seq_id + i,
587
+ batch.seq_id + i,
588
+ batch.logits + i,
589
+ 0 , 0 , 0 , // unused
590
+ };
591
+
592
+ const int ret = llama_decode (ctx, batch_view);
593
+ if (ret != 0 ) {
594
+ LOG_TEE (" failed to decode the batch, n_batch = %d, ret = %d\n " , n_batch, ret);
595
+ return false ;
555
596
}
556
- }
557
597
558
- // Do the 1st ending
559
- // In this case we include the context when evaluating
560
- // auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
561
- auto query_embd = ending_tokens[0 ];
562
- auto query_size = query_embd.size ();
563
-
564
- // Stop if query wont fit the ctx window
565
- if (query_size > (size_t )n_ctx) {
566
- fprintf (stderr, " %s : number of tokens in query %zu > n_ctxl\n " , __func__, query_size);
567
- return ;
598
+ memcpy (batch_logits.data () + i*n_vocab, llama_get_logits (ctx), n_tokens*n_vocab*sizeof (float ));
568
599
}
569
600
570
- // Speedup small evaluations by evaluating atleast 32 tokens
571
- if (query_size < 32 ) {
572
- query_embd.resize (32 );
573
- }
601
+ return true ;
602
+ };
574
603
575
-        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+    for (size_t i0 = 0; i0 < hs_task_count; i0++) {
+        int n_cur = 0;

-        auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
-        if (logits.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
-        }
+        size_t i1 = i0;
+        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+
+        llama_batch_clear(batch);

-        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
-        const auto first_probs = softmax(tok_logits);
+        // batch as much tasks as possible into the available context
+        // each task has 4 unique seuqnce ids - one for each ending
+        // the common prefix is shared among the 4 sequences to save tokens
+        // we extract logits only from the last common token and from all ending tokens of each sequence
+        while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
+            auto & hs_cur = hs_data[i1];

-        hs_data[task_idx].ending_logprob_count[0] = 1;
-        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+            const int s0 = 4*(i1 - i0);
+            if (s0 + 4 > max_seq) {
+                break;
+            }

-        // Calculate the logprobs over the ending
-        for (size_t j = context_size; j < query_size - 1; j++) {
+            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
+                llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix

-            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }

-            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+            hs_cur.i_batch = i_batch;
+            i_batch += hs_cur.required_tokens;

-            hs_data[task_idx].ending_logprob[0] += std::log(prob);
-            hs_data[task_idx].ending_logprob_count[0]++;
+            n_cur += hs_data[i1].required_tokens;
+            if (++i1 == hs_task_count) {
+                break;
+            }
         }

-        // Calculate the mean token logprob for acc_norm
-        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }

-        // Do the remaining endings
-        // For these, we use the bare ending with n_past = context_size
-        //
-        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
+        llama_kv_cache_clear(ctx);

-            // Tokenize the query
-            query_embd.resize(ending_tokens[ending_idx].size() - context_size);
-            std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
-            query_size = query_embd.size();
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, n_batch)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+            return;
+        }

-            // Stop if query wont fit the ctx window
-            if (context_size + query_size > (size_t)n_ctx) {
-                fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-                return;
-            }
+        // compute the logprobs for each ending of the decoded tasks
+        for (size_t i = i0; i < i1; ++i) {
+            auto & hs_cur = hs_data[i];

-            // Speedup small evaluations by evaluating atleast 32 tokens
-            // No, resizing to 32 is actually slightly slower (at least on CUDA)
-            //if (query_size < 32) {
-            //    query_embd.resize(32);
-            //}
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));

-            // Evaluate the query
-            logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
-            if (logits.empty()) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
-            }
+            const auto first_probs = softmax(tok_logits);

-            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
-            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
+            size_t li = hs_cur.common_prefix; // logits index in the batch

-            // Calculate the logprobs over the ending
-            for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                hs_cur.ending_logprob_count[s] = 1;
+                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);

-                const float prob = softmax(tok_logits)[query_embd[j + 1]];
+                // Calculate the logprobs over the ending
+                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                    std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));

-                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
-                hs_data[task_idx].ending_logprob_count[ending_idx]++;
-            }
+                    const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];

-            // Calculate the mean token logprob for acc_norm
-            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
+                    hs_cur.ending_logprob[s] += std::log(prob);
+                    hs_cur.ending_logprob_count[s]++;
+                }

+                // account that we skip the last token in the ending
+                ++li;

-            //printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
-            //    task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
-        }
+                // Calculate the mean token logprob for acc_norm
+                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
+            }

-        // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = 0;
-        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
-        for (size_t j = 1; j < 4; j++) {
-            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
-                ending_logprob_max_idx = j;
-                ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
+            // Find the ending with maximum logprob
+            size_t ending_logprob_max_idx = 0;
+            double ending_logprob_max_val = hs_cur.ending_logprob[0];
+            for (size_t s = 1; s < 4; s++) {
+                if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
+                    ending_logprob_max_idx = s;
+                    ending_logprob_max_val = hs_cur.ending_logprob[s];
+                }
             }
-        }

-        //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
+            //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
+
+            // If the gold ending got the maximum logprobe add one accuracy point
+            if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
+                acc += 1.0;
+            }

-        // If the gold ending got the maximum logprobe add one accuracy point
-        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
-            acc += 1.0;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+            fflush(stdout);
         }

-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
-        fflush(stdout);
+        i0 = i1 - 1;
     }

-    delete [] hs_data;
+    llama_batch_free(batch);

     printf("\n");
 }