@@ -425,8 +425,9 @@ int main(int argc, char ** argv) {
425
425
LOG_TEE (" generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n " , n_ctx, params.n_batch , params.n_predict , params.n_keep );
426
426
LOG_TEE (" \n\n " );
427
427
428
+ struct llama_grammar * grammar = NULL ;
428
429
grammar_parser::parse_state parsed_grammar;
429
- llama_grammar * grammar = NULL ;
430
+
430
431
if (!params.grammar .empty ()) {
431
432
parsed_grammar = grammar_parser::parse (params.grammar .c_str ());
432
433
// will be empty (default) if there are parse errors
@@ -450,8 +451,8 @@ int main(int argc, char ** argv) {
450
451
}
451
452
452
453
// TODO: replace with ring-buffer
453
- std::vector<llama_token> last_n_tokens (n_ctx);
454
- std::fill (last_n_tokens .begin (), last_n_tokens .end (), 0 );
454
+ std::vector<llama_token> last_tokens (n_ctx);
455
+ std::fill (last_tokens .begin (), last_tokens .end (), 0 );
455
456
456
457
if (params.interactive ) {
457
458
const char *control_message;
@@ -500,6 +501,11 @@ int main(int argc, char ** argv) {
500
501
llama_reset_timings (ctx);
501
502
}
502
503
504
+ const int n_vocab = llama_n_vocab (ctx);
505
+
506
+ std::vector<llama_token_data> candidates;
507
+ candidates.reserve (n_vocab);
508
+
503
509
while ((n_remain != 0 && !is_antiprompt) || params.interactive ) {
504
510
// predict
505
511
if (embd.size () > 0 ) {
@@ -537,8 +543,8 @@ int main(int argc, char ** argv) {
537
543
538
544
LOG (" after swap: n_past = %d, n_past_guidance = %d\n " , n_past, n_past_guidance);
539
545
540
- // insert n_left/2 tokens at the start of embd from last_n_tokens
541
- embd.insert (embd.begin (), last_n_tokens .begin () + n_ctx - n_left/2 - embd.size (), last_n_tokens .end () - embd.size ());
546
+ // insert n_left/2 tokens at the start of embd from last_tokens
547
+ embd.insert (embd.begin (), last_tokens .begin () + n_ctx - n_left/2 - embd.size (), last_tokens .end () - embd.size ());
542
548
543
549
LOG (" embd: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, embd));
544
550
@@ -637,20 +643,6 @@ int main(int argc, char ** argv) {
637
643
embd_guidance.clear ();
638
644
639
645
if ((int ) embd_inp.size () <= n_consumed && !is_interacting) {
640
- const float temp = params.temp ;
641
- const int32_t top_k = params.top_k <= 0 ? llama_n_vocab (ctx) : params.top_k ;
642
- const float top_p = params.top_p ;
643
- const float tfs_z = params.tfs_z ;
644
- const float typical_p = params.typical_p ;
645
- const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n ;
646
- const float repeat_penalty = params.repeat_penalty ;
647
- const float alpha_presence = params.presence_penalty ;
648
- const float alpha_frequency = params.frequency_penalty ;
649
- const int mirostat = params.mirostat ;
650
- const float mirostat_tau = params.mirostat_tau ;
651
- const float mirostat_eta = params.mirostat_eta ;
652
- const bool penalize_nl = params.penalize_nl ;
653
-
654
646
// optionally save the session on first sample (for faster prompt loading next time)
655
647
if (!path_session.empty () && need_to_save_session && !params.prompt_cache_ro ) {
656
648
need_to_save_session = false ;
@@ -659,98 +651,12 @@ int main(int argc, char ** argv) {
659
651
LOG (" saved session to %s\n " , path_session.c_str ());
660
652
}
661
653
662
- llama_token id = 0 ;
663
-
664
- {
665
- auto logits = llama_get_logits (ctx);
666
- auto n_vocab = llama_n_vocab (ctx);
667
-
668
- // Apply params.logit_bias map
669
- for (auto it = params.logit_bias .begin (); it != params.logit_bias .end (); it++) {
670
- logits[it->first ] += it->second ;
671
- }
672
-
673
- std::vector<llama_token_data> candidates;
674
- candidates.reserve (n_vocab);
675
- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
676
- candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
677
- }
678
-
679
- llama_token_data_array cur_p = { candidates.data (), candidates.size (), false };
680
-
681
- if (ctx_guidance) {
682
- llama_sample_classifier_free_guidance (ctx, &cur_p, ctx_guidance, params.cfg_scale );
683
- }
684
-
685
- // Apply penalties
686
- float nl_logit = logits[llama_token_nl (ctx)];
687
- auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), repeat_last_n), n_ctx);
688
- llama_sample_repetition_penalty (ctx, &cur_p,
689
- last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
690
- last_n_repeat, repeat_penalty);
691
- llama_sample_frequency_and_presence_penalties (ctx, &cur_p,
692
- last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
693
- last_n_repeat, alpha_frequency, alpha_presence);
694
- if (!penalize_nl) {
695
- for (size_t idx = 0 ; idx < cur_p.size ; idx++) {
696
- if (cur_p.data [idx].id == llama_token_nl (ctx)) {
697
- cur_p.data [idx].logit = nl_logit;
698
- break ;
699
- }
700
- }
701
- }
702
-
703
- if (grammar != NULL ) {
704
- llama_sample_grammar (ctx, &cur_p, grammar);
705
- }
706
-
707
- if (temp <= 0 ) {
708
- // Greedy sampling
709
- id = llama_sample_token_greedy (ctx, &cur_p);
710
- } else {
711
- if (mirostat == 1 ) {
712
- static float mirostat_mu = 2 .0f * mirostat_tau;
713
- const int mirostat_m = 100 ;
714
- llama_sample_temperature (ctx, &cur_p, temp);
715
- id = llama_sample_token_mirostat (ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
716
- } else if (mirostat == 2 ) {
717
- static float mirostat_mu = 2 .0f * mirostat_tau;
718
- llama_sample_temperature (ctx, &cur_p, temp);
719
- id = llama_sample_token_mirostat_v2 (ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
720
- } else {
721
- // Temperature sampling
722
- llama_sample_top_k (ctx, &cur_p, top_k, 1 );
723
- llama_sample_tail_free (ctx, &cur_p, tfs_z, 1 );
724
- llama_sample_typical (ctx, &cur_p, typical_p, 1 );
725
- llama_sample_top_p (ctx, &cur_p, top_p, 1 );
726
- llama_sample_temperature (ctx, &cur_p, temp);
727
-
728
- {
729
- const int n_top = 10 ;
730
- LOG (" top %d candidates:\n " , n_top);
731
-
732
- for (int i = 0 ; i < n_top; i++) {
733
- const llama_token id = cur_p.data [i].id ;
734
- LOG (" - %5d: '%12s' (%.3f)\n " , id, llama_token_to_piece (ctx, id).c_str (), cur_p.data [i].p );
735
- }
736
- }
737
-
738
- id = llama_sample_token (ctx, &cur_p);
654
+ const llama_token id = llama_sample_token (ctx, ctx_guidance, grammar, params, last_tokens, candidates);
739
655
740
- LOG (" sampled token: %5d: '%s'\n " , id, llama_token_to_piece (ctx, id).c_str ());
741
- }
742
- }
743
- // printf("`%d`", candidates_p.size);
656
+ last_tokens.erase (last_tokens.begin ());
657
+ last_tokens.push_back (id);
744
658
745
- if (grammar != NULL ) {
746
- llama_grammar_accept_token (ctx, grammar, id);
747
- }
748
-
749
- last_n_tokens.erase (last_n_tokens.begin ());
750
- last_n_tokens.push_back (id);
751
-
752
- LOG (" last: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, last_n_tokens));
753
- }
659
+ LOG (" last: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, last_tokens));
754
660
755
661
embd.push_back (id);
756
662
@@ -766,8 +672,8 @@ int main(int argc, char ** argv) {
766
672
LOG (" embd_inp.size(): %d, n_consumed: %d\n " , (int ) embd_inp.size (), n_consumed);
767
673
while ((int ) embd_inp.size () > n_consumed) {
768
674
embd.push_back (embd_inp[n_consumed]);
769
- last_n_tokens .erase (last_n_tokens .begin ());
770
- last_n_tokens .push_back (embd_inp[n_consumed]);
675
+ last_tokens .erase (last_tokens .begin ());
676
+ last_tokens .push_back (embd_inp[n_consumed]);
771
677
++n_consumed;
772
678
if ((int ) embd.size () >= params.n_batch ) {
773
679
break ;
@@ -800,7 +706,7 @@ int main(int argc, char ** argv) {
800
706
// check for reverse prompt
801
707
if (params.antiprompt .size ()) {
802
708
std::string last_output;
803
- for (auto id : last_n_tokens ) {
709
+ for (auto id : last_tokens ) {
804
710
last_output += llama_token_to_piece (ctx, id);
805
711
}
806
712
@@ -831,7 +737,7 @@ int main(int argc, char ** argv) {
831
737
}
832
738
833
739
// deal with end of text token in interactive mode
834
- if (last_n_tokens .back () == llama_token_eos (ctx)) {
740
+ if (last_tokens .back () == llama_token_eos (ctx)) {
835
741
LOG (" found EOS token\n " );
836
742
837
743
if (params.interactive ) {
@@ -933,7 +839,7 @@ int main(int argc, char ** argv) {
933
839
if (grammar != NULL ) {
934
840
llama_grammar_free (grammar);
935
841
936
- std::vector<const llama_grammar_element *> grammar_rules ( parsed_grammar.c_rules ());
842
+ std::vector<const llama_grammar_element *> grammar_rules (parsed_grammar.c_rules ());
937
843
grammar = llama_grammar_init (
938
844
grammar_rules.data (), grammar_rules.size (),
939
845
parsed_grammar.symbol_ids .at (" root" ));
0 commit comments