
Commit ad19812

perplexity : faster HellaSwag via batching (#5017)

* perplexity : faster HellaSwag (ggml-ci)
* perplexity : clean-up (ggml-ci)
* perplexity : no need for decode_helper (ggml-ci)
* perplexity : add comments
* perplexity : option to specify max batched tasks via `n_parallel`
* perplexity : remove HellaSwag restriction for `n_batch`

1 parent 682986a · commit ad19812
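In short: instead of running a separate decode for every candidate ending, the scorer now packs up to `n_parallel` tasks into a single `llama_batch`, one sequence id per ending, with the common prefix of each task's four tokenizations shared between them. A run would then look something like `./perplexity -m model.gguf -f hellaswag_val_full.txt --hellaswag --hellaswag-tasks 400 -np 8` (illustrative invocation only; `-np` sets `n_parallel` in the common options, and exact flag spellings may differ by revision, so check `--help` on your build).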

File tree

1 file changed: examples/perplexity/perplexity.cpp (+148, -111 lines)
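To make the diff easier to follow: each scored task occupies four sequence ids in the batch, one per candidate ending. The common prefix of the four tokenizations is added once under all four ids, and logits are requested only for the last common token plus every token of each ending's suffix. A toy, self-contained sketch of that layout (invented values, no llama.cpp types; printf stands in for llama_batch_add):

    #include <cstdio>
    #include <vector>

    int main() {
        const int    s0            = 0; // first sequence id of the task, 4*(i1 - i0) in the diff
        const size_t common_prefix = 3; // tokens shared by all four endings
        const std::vector<size_t> seq_len = {5, 6, 4, 5}; // full length of each ending sequence

        // common prefix: one copy of each token, tagged with all four sequence ids;
        // only the last common token requests logits (the diff flips that flag after the loop)
        for (size_t i = 0; i < common_prefix; ++i) {
            printf("pos %zu  seq_ids {%d,%d,%d,%d}  logits=%d\n",
                   i, s0 + 0, s0 + 1, s0 + 2, s0 + 3, i == common_prefix - 1);
        }

        // each ending's unique suffix: its own sequence id, logits for every token
        for (int s = 0; s < 4; ++s) {
            for (size_t i = common_prefix; i < seq_len[s]; ++i) {
                printf("pos %zu  seq_ids {%d}  logits=1\n", i, s0 + s);
            }
        }
        return 0;
    }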
@@ -470,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         prompt_lines.push_back(line);
     }
 
-    if( prompt_lines.size() % 6 != 0) {
+    if (prompt_lines.size() % 6 != 0) {
         fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
         return;
     }
@@ -485,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     // Number of tasks to use when computing the score
-    if ( params.hellaswag_tasks < hs_task_count ) {
+    if (params.hellaswag_tasks < hs_task_count) {
         hs_task_count = params.hellaswag_tasks;
     }
 
@@ -502,178 +502,215 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         std::string ending[4];
         size_t ending_logprob_count[4];
         double ending_logprob[4];
+
+        size_t i_batch;         // starting index in the llama_batch
+        size_t common_prefix;   // max number of initial tokens that are the same in all sentences
+        size_t required_tokens; // needed number of tokens to evaluate all 4 endings
+        std::vector<llama_token> seq_tokens[4];
     };
 
     fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
 
     // Select and read data from prompt lines
-    hs_data_t *hs_data = new hs_data_t[hs_task_count];
-    for (size_t i=0; i < hs_task_count; i++) {
+    std::vector<hs_data_t> hs_data(hs_task_count);
+    for (size_t i = 0; i < hs_task_count; i++) {
         size_t idx = i;
 
+        auto & hs_cur = hs_data[i];
+
         // Select a random example of those left in the prompt
         if (randomize_tasks) {
             std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
             idx = dist(rng);
         }
 
-        hs_data[i].context = prompt_lines[idx*6];
-        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
-        for (size_t j=0; j < 4; j++) {
-            hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
+        hs_cur.context = prompt_lines[idx*6];
+        hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
+        for (size_t j = 0; j < 4; j++) {
+            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
+            hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
         }
 
+        // determine the common prefix of the endings
+        hs_cur.common_prefix = 0;
+        hs_cur.required_tokens = 0;
+        for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
+            if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
+                break;
+            }
+            hs_cur.common_prefix++;
+        }
+        hs_cur.required_tokens = hs_cur.common_prefix +
+            hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
+
+        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
+
         // Delete the selected random example from the prompt
         if (randomize_tasks) {
            prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
         }
     }
 
     fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
+
     printf("\ntask\tacc_norm\n");
 
     double acc = 0.0f;
+
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
 
-    std::vector<std::vector<int>> ending_tokens(4);
+    const int max_tasks_per_batch = params.n_parallel;
+    const int max_seq = 4*max_tasks_per_batch;
 
-    std::vector<float> tok_logits(n_vocab);
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
-    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-        // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
-        size_t context_size = context_embd.size();
-
-        for (int i = 0; i < 4; ++i) {
-            ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
-            for (int k = 0; k < int(context_size); ++k) {
-                if (ending_tokens[i][k] != context_embd[k]) {
-                    fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
-                    break;
-                }
+    std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_ctx*n_vocab);
+
+    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token    + i,
+                nullptr,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
+                0, 0, 0, // unused
+            };
+
+            const int ret = llama_decode(ctx, batch_view);
+            if (ret != 0) {
+                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                return false;
             }
-        }
 
-        // Do the 1st ending
-        // In this case we include the context when evaluating
-        //auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
-        auto query_embd = ending_tokens[0];
-        auto query_size = query_embd.size();
-
-        // Stop if query wont fit the ctx window
-        if (query_size > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-            return;
+            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
         }
 
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        if (query_size < 32) {
-            query_embd.resize(32);
-        }
+        return true;
+    };
 
-        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+    for (size_t i0 = 0; i0 < hs_task_count; i0++) {
+        int n_cur = 0;
 
-        auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
-        if (logits.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
-        }
+        size_t i1 = i0;
+        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+
+        llama_batch_clear(batch);
 
-        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
-        const auto first_probs = softmax(tok_logits);
+        // batch as many tasks as possible into the available context
+        // each task has 4 unique sequence ids - one for each ending
+        // the common prefix is shared among the 4 sequences to save tokens
+        // we extract logits only from the last common token and from all ending tokens of each sequence
+        while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
+            auto & hs_cur = hs_data[i1];
 
-        hs_data[task_idx].ending_logprob_count[0] = 1;
-        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+            const int s0 = 4*(i1 - i0);
+            if (s0 + 4 > max_seq) {
+                break;
+            }
 
-        // Calculate the logprobs over the ending
-        for (size_t j = context_size; j < query_size - 1; j++) {
+            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
+                llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
 
-            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
 
-            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+            hs_cur.i_batch = i_batch;
+            i_batch += hs_cur.required_tokens;
 
-            hs_data[task_idx].ending_logprob[0] += std::log(prob);
-            hs_data[task_idx].ending_logprob_count[0]++;
+            n_cur += hs_data[i1].required_tokens;
+            if (++i1 == hs_task_count) {
+                break;
+            }
         }
 
-        // Calculate the mean token logprob for acc_norm
-        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
 
-        // Do the remaining endings
-        // For these, we use the bare ending with n_past = context_size
-        //
-        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
+        llama_kv_cache_clear(ctx);
 
-            // Tokenize the query
-            query_embd.resize(ending_tokens[ending_idx].size() - context_size);
-            std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
-            query_size = query_embd.size();
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, n_batch)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+            return;
+        }
 
-            // Stop if query wont fit the ctx window
-            if (context_size + query_size > (size_t)n_ctx) {
-                fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-                return;
-            }
+        // compute the logprobs for each ending of the decoded tasks
+        for (size_t i = i0; i < i1; ++i) {
+            auto & hs_cur = hs_data[i];
 
-            // Speedup small evaluations by evaluating atleast 32 tokens
-            // No, resizing to 32 is actually slightly slower (at least on CUDA)
-            //if (query_size < 32) {
-            //    query_embd.resize(32);
-            //}
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
 
-            // Evaluate the query
-            logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
-            if (logits.empty()) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
-            }
+            const auto first_probs = softmax(tok_logits);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
-            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
+            size_t li = hs_cur.common_prefix; // logits index in the batch
 
-            // Calculate the logprobs over the ending
-            for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                hs_cur.ending_logprob_count[s] = 1;
+                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
 
-                const float prob = softmax(tok_logits)[query_embd[j + 1]];
+                // Calculate the logprobs over the ending
+                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                    std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
 
-                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
-                hs_data[task_idx].ending_logprob_count[ending_idx]++;
-            }
+                    const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
 
-            // Calculate the mean token logprob for acc_norm
-            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
+                    hs_cur.ending_logprob[s] += std::log(prob);
+                    hs_cur.ending_logprob_count[s]++;
+                }
 
+                // account that we skip the last token in the ending
+                ++li;
 
-            // printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
-            //        task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
-        }
+                // Calculate the mean token logprob for acc_norm
+                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
+            }
 
-        // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = 0;
-        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
-        for (size_t j = 1; j < 4; j++) {
-            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
-                ending_logprob_max_idx = j;
-                ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
+            // Find the ending with maximum logprob
+            size_t ending_logprob_max_idx = 0;
+            double ending_logprob_max_val = hs_cur.ending_logprob[0];
+            for (size_t s = 1; s < 4; s++) {
+                if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
+                    ending_logprob_max_idx = s;
+                    ending_logprob_max_val = hs_cur.ending_logprob[s];
+                }
             }
-        }
 
-        // printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
+            //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
+
+            // If the gold ending got the maximum logprob add one accuracy point
+            if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
+                acc += 1.0;
+            }
 
-        // If the gold ending got the maximum logprobe add one accuracy point
-        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
-            acc += 1.0;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+            fflush(stdout);
         }
 
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
-        fflush(stdout);
+        i0 = i1 - 1;
     }
 
-    delete [] hs_data;
+    llama_batch_free(batch);
 
     printf("\n");
 }
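A note on the arithmetic behind `required_tokens`: the four tokenizations of `context + " " + ending[s]` agree on at least the context tokens, so a task needs only `common_prefix` plus the sum over the four endings of `(length - common_prefix)` batch positions, rather than the sum of the full lengths. With lengths 20, 22, 21, 23 and a common prefix of 15, that is 15 + 5 + 7 + 6 + 8 = 41 tokens instead of 86. A self-contained sketch of the same computation, with invented toy token values in place of real tokenizer output:

    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<int> seq[4] = {
            {1, 7, 7, 9, 2, 4},    // context + ending 0
            {1, 7, 7, 9, 5},       // context + ending 1
            {1, 7, 7, 9, 6, 8, 3}, // context + ending 2
            {1, 7, 7, 9, 2, 2},    // context + ending 3
        };

        // longest prefix shared by all four sequences (the tokenized context,
        // plus any ending tokens that happen to coincide); bounds-checked here,
        // the diff relies on the sequences diverging before the shortest ends
        size_t common_prefix = 0;
        for (size_t k = 0; k < seq[0].size(); k++) {
            if (k >= seq[1].size() || k >= seq[2].size() || k >= seq[3].size() ||
                seq[0][k] != seq[1][k] || seq[0][k] != seq[2][k] || seq[0][k] != seq[3][k]) {
                break;
            }
            common_prefix++;
        }

        // the prefix is decoded once; each ending pays only for its unique suffix
        size_t required_tokens = common_prefix;
        size_t naive_tokens    = 0;
        for (int s = 0; s < 4; ++s) {
            required_tokens += seq[s].size() - common_prefix;
            naive_tokens    += seq[s].size();
        }

        printf("common_prefix   = %zu\n", common_prefix);  // 4
        printf("required_tokens = %zu (naive: %zu)\n",
               required_tokens, naive_tokens);              // 12 (naive: 24)
        return 0;
    }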

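Finally, the `decode_helper` lambda is what lifts the old requirement that a HellaSwag evaluation fit within `n_batch`: the assembled batch may hold many tasks, and it is submitted as consecutive views of at most `n_batch` tokens over the same underlying arrays, with each view's logits copied into `batch_logits` at the matching token offset. A minimal sketch of just that chunking pattern, with toy sizes and printf in place of llama_decode:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_tokens_total = 10; // batch.n_tokens after packing several tasks
        const int n_batch        = 4;  // max tokens per llama_decode call

        for (int i = 0; i < n_tokens_total; i += n_batch) {
            const int n_tokens = std::min(n_batch, n_tokens_total - i);
            // a real view aliases batch.token + i, batch.pos + i, ... and the
            // logits for these n_tokens land at row offset i of batch_logits
            printf("decode view: tokens [%d, %d)\n", i, i + n_tokens);
        }
        return 0;
    }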