@@ -534,98 +534,20 @@ struct llama_server_context
534
534
return result;
535
535
}
536
536
537
- // out of user input, sample next token
538
- const float temp = params.temp ;
539
- const int32_t top_k = params.top_k <= 0 ? llama_n_vocab (model) : params.top_k ;
540
- const float top_p = params.top_p ;
541
- const float tfs_z = params.tfs_z ;
542
- const float typical_p = params.typical_p ;
543
- const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n ;
544
- const float repeat_penalty = params.repeat_penalty ;
545
- const float alpha_presence = params.presence_penalty ;
546
- const float alpha_frequency = params.frequency_penalty ;
547
- const int mirostat = params.mirostat ;
548
- const float mirostat_tau = params.mirostat_tau ;
549
- const float mirostat_eta = params.mirostat_eta ;
550
- const bool penalize_nl = params.penalize_nl ;
551
- const int32_t n_probs = params.n_probs ;
552
-
553
537
{
554
- auto *logits = llama_get_logits (ctx);
555
- auto n_vocab = llama_n_vocab (model);
556
-
557
- // Apply params.logit_bias map
558
- for (const auto &it : params.logit_bias )
559
- {
560
- logits[it.first ] += it.second ;
561
- }
562
-
538
+ // out of user input, sample next token
563
539
std::vector<llama_token_data> candidates;
564
- candidates.reserve (n_vocab);
565
- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++)
566
- {
567
- candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
568
- }
540
+ candidates.reserve (llama_n_vocab (model));
569
541
570
- llama_token_data_array candidates_p = {candidates.data (), candidates.size (), false };
571
-
572
- // Apply penalties
573
- float nl_logit = logits[llama_token_nl (ctx)];
574
- auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), repeat_last_n), n_ctx);
575
- llama_sample_repetition_penalty (ctx, &candidates_p,
576
- last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
577
- last_n_repeat, repeat_penalty);
578
- llama_sample_frequency_and_presence_penalties (ctx, &candidates_p,
579
- last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
580
- last_n_repeat, alpha_frequency, alpha_presence);
581
- if (!penalize_nl)
582
- {
583
- logits[llama_token_nl (ctx)] = nl_logit;
584
- }
542
+ result.tok = llama_sample_token (ctx, NULL , grammar, params, last_n_tokens, candidates);
585
543
586
- if (grammar != nullptr ) {
587
- llama_sample_grammar (ctx, &candidates_p, grammar);
588
- }
544
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
589
545
590
- if (temp <= 0 )
591
- {
592
- // Greedy sampling
593
- result.tok = llama_sample_token_greedy (ctx, &candidates_p);
594
- if (n_probs > 0 )
595
- {
596
- llama_sample_softmax (ctx, &candidates_p);
597
- }
598
- }
599
- else
546
+ const int32_t n_probs = params.n_probs ;
547
+ if (params.temp <= 0 && n_probs > 0 )
600
548
{
601
- if (mirostat == 1 )
602
- {
603
- static float mirostat_mu = 2 .0f * mirostat_tau;
604
- const int mirostat_m = 100 ;
605
- llama_sample_temp (ctx, &candidates_p, temp);
606
- result.tok = llama_sample_token_mirostat (ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
607
- }
608
- else if (mirostat == 2 )
609
- {
610
- static float mirostat_mu = 2 .0f * mirostat_tau;
611
- llama_sample_temp (ctx, &candidates_p, temp);
612
- result.tok = llama_sample_token_mirostat_v2 (ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
613
- }
614
- else
615
- {
616
- // Temperature sampling
617
- size_t min_keep = std::max (1 , n_probs);
618
- llama_sample_top_k (ctx, &candidates_p, top_k, min_keep);
619
- llama_sample_tail_free (ctx, &candidates_p, tfs_z, min_keep);
620
- llama_sample_typical (ctx, &candidates_p, typical_p, min_keep);
621
- llama_sample_top_p (ctx, &candidates_p, top_p, min_keep);
622
- llama_sample_temp (ctx, &candidates_p, temp);
623
- result.tok = llama_sample_token (ctx, &candidates_p);
624
- }
625
- }
626
-
627
- if (grammar != nullptr ) {
628
- llama_grammar_accept_token (ctx, grammar, result.tok );
549
+ // For llama_sample_token_greedy we need to sort candidates
550
+ llama_sample_softmax (ctx, &candidates_p);
629
551
}
630
552
631
553
for (size_t i = 0 ; i < std::min (candidates_p.size , (size_t )n_probs); ++i)
0 commit comments