
Commit c82c808

speculative : initial example
1 parent 8f429fa commit c82c808

File tree

6 files changed: +425 -115 lines

common/common.cpp

Lines changed: 133 additions & 0 deletions
@@ -317,6 +317,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -669,6 +675,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, " -m FNAME, --model FNAME\n");
     fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, " -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, " draft model for speculative sampling (default: %s)\n", params.model.c_str());
     fprintf(stdout, " -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, " path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
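
With both changes in place, any example that parses its arguments through gpt_params_parse can accept a draft model next to the main one: the full target model goes to -m / --model and the smaller draft model to -md / --model-draft. A hypothetical invocation of the new speculative example could look like the following (the binary name, model paths and prompt are placeholders, not taken from this commit):

    ./bin/speculative -m models/13B/ggml-model-f16.gguf -md models/7B/ggml-model-f16.gguf -p "Once upon a time"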
@@ -832,6 +840,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k    (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free(ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical  (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p    (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
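
The idx argument is what makes this helper usable for the new speculative example: when several tokens are evaluated in a single call and logits are kept for every position, the same routine can sample the target model at each drafted position. Below is a minimal sketch of such a verification loop; it assumes the llama_eval API of this revision, a target context created with logits_all enabled, and illustrative names (ctx_tgt, drafted, n_past) that do not come from this commit:

    // drafted[0] is the last accepted token, drafted[1..] are the draft model's guesses
    // evaluate them all at once so logits are available for every position
    llama_eval(ctx_tgt, drafted.data(), (int) drafted.size(), n_past, params.n_threads);

    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx_tgt));

    std::vector<llama_token> last_tokens; // empty -> repetition penalties are skipped

    for (size_t i = 0; i < drafted.size(); ++i) {
        // row i of the logits is the target's prediction for the token following drafted[i]
        const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, (int) i);

        // accept the guess only if the target would have produced the same token
        if (i + 1 >= drafted.size() || id != drafted[i + 1]) {
            // keep id as the next token and discard the remaining draft
            break;
        }
    }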
@@ -1070,6 +1202,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);

common/common.h

Lines changed: 35 additions & 0 deletions
@@ -63,6 +63,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance
 
     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative sampling
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
@@ -156,6 +157,40 @@ std::string llama_detokenize_bpe(
         llama_context * ctx,
         const std::vector<llama_token> & tokens);
 
+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+//  - ctx:    context to use for sampling
+//  - params: sampling parameters
+//
+// optional:
+//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+//  - grammar:      grammar to use for sampling, ignore if NULL
+//  - last_tokens:  needed for repetition penalty, ignore if empty
+//  - idx:          sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+//  - token:      sampled token
+//  - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx = 0);
+
+//
+// YAML utils
+//
+
 bool create_directory_with_parents(const std::string & path);
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
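
As the comment block above spells out, only ctx and params are required: the guidance context and grammar may be NULL, last_tokens may be empty, and idx defaults to 0. A minimal call for ordinary single-stream generation could therefore be sketched as follows (assuming ctx already holds an evaluated prompt and params is a filled gpt_params; the surrounding generation loop is omitted):

    std::vector<llama_token> recent_tokens;   // history fed to the repetition penalty
    std::vector<llama_token_data> candidates; // scratch buffer, refilled on every call

    // no guidance context, no grammar, sample from the logits of the last evaluated token (idx = 0)
    const llama_token id = llama_sample_token(ctx, NULL, NULL, params, recent_tokens, candidates, 0);

    recent_tokens.push_back(id);              // keep the history up to date for the next call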

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(speculative)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
