@@ -317,6 +317,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
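
This hunk reads a `params.model_draft` member, so the companion header change (not shown in this diff) presumably adds a field to `gpt_params`. A minimal sketch, assuming the default stays empty so the feature is opt-in:

    // common.h (hypothetical excerpt; only model_draft is new)
    struct gpt_params {
        // ...
        std::string model       = "models/7B/ggml-model.bin"; // base model path
        std::string model_draft = "";                         // draft model for speculative sampling, set via -md/--model-draft
        // ...
    };
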
@@ -669,6 +675,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -832,6 +840,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+               const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+         std::vector<llama_token_data> & candidates,
+                                    int   idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
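
The new helper folds the whole sampling pipeline into one call: logit bias, optional classifier-free guidance, repetition/frequency/presence penalties, grammar constraints, and finally greedy, mirostat, or temperature sampling. The `idx` parameter selects a row of the logits buffer, so a caller can sample any position of a batched evaluation. Note that the mirostat state (`mirostat_mu`) is function-`static`, so it is shared by every context in the process rather than kept per call. A minimal caller sketch, assuming `ctx` and `params` are set up as elsewhere in this file and no guidance context or grammar is used:

    // reuse one buffer across calls so the n_vocab-sized candidate list
    // is not reallocated for every sampled token
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx));

    // recent tokens consumed by the repetition/frequency/presence penalties
    std::vector<llama_token> last_tokens(llama_n_ctx(ctx), 0);

    // sample from logits row 0 (the only row when a single token was evaluated)
    const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates, 0);

    // slide the penalty window forward
    last_tokens.erase(last_tokens.begin());
    last_tokens.push_back(id);
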
@@ -1070,6 +1202,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);