@@ -317,6 +317,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -669,6 +675,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
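
Note (commentary, not part of the patch): the new `--model-draft` path is only parsed and printed here; a consumer such as a speculative-sampling example is expected to load a second, smaller model from `params.model_draft` alongside the main model. A minimal hedged sketch of how a caller might do that, assuming the `llama_init_from_gpt_params()` helper from `common.h` of this era, which returns a (model, context) pair:

#include <tuple>

#include "common.h"
#include "llama.h"

// Hypothetical caller-side sketch (not from this patch): load the target and
// the draft model from the parsed gpt_params.
static int load_models(gpt_params params) {
    llama_model * model_tgt = nullptr; llama_context * ctx_tgt = nullptr;
    llama_model * model_dft = nullptr; llama_context * ctx_dft = nullptr;

    // target model from -m / --model
    std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

    // draft model from -md / --model-draft, reusing the same loader
    params.model = params.model_draft;
    std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);

    return (ctx_tgt && ctx_dft) ? 0 : 1;
}
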
@@ -824,6 +832,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+          const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+         std::vector<llama_token_data> & candidates,
+                                     int   idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
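
As illustration (commentary rather than part of the diff): `llama_sample_token()` is meant to be called once per generated token after an eval, with `candidates` reused across calls and `idx` selecting which row of the logits buffer to sample from. Below is a minimal hedged sketch of a plain generation loop under the following assumptions: the era's `llama_eval()` API, no classifier-free-guidance context, no grammar, and a `ctx`/`params` pair prepared by the surrounding program.

#include <cstdio>
#include <vector>

#include "common.h"
#include "llama.h"

// Hypothetical usage sketch (assumptions noted above; not from this patch).
static void generate(llama_context * ctx, const gpt_params & params) {
    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
    std::vector<llama_token> last_tokens(tokens);              // history used for repetition penalties
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx));

    // evaluate the prompt, then generate one token at a time
    llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, params.n_threads);
    int n_past = (int) tokens.size();

    for (int i = 0; i < params.n_predict; ++i) {
        // only the last evaluated token has logits here, hence idx == 0
        const llama_token id = llama_sample_token(ctx, /*ctx_guidance=*/nullptr, /*grammar=*/nullptr,
                                                  params, last_tokens, candidates, /*idx=*/0);
        last_tokens.push_back(id);
        printf("%s", llama_token_to_piece(ctx, id).c_str());

        if (id == llama_token_eos(ctx)) {
            break;
        }

        llama_eval(ctx, &id, 1, n_past, params.n_threads);
        n_past += 1;
    }
}
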
@@ -1062,6 +1194,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
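
For reference, with the hunk above a run passing hypothetical paths such as `-m models/7B/ggml-model.gguf -md models/7B/ggml-model-draft.gguf` would produce YAML log lines along these lines (paths are illustrative, not from the patch):

model: models/7B/ggml-model.gguf # default: models/7B/ggml-model.bin
model_draft: models/7B/ggml-model-draft.gguf # default: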