
Commit c33cd8a

speculative : initial example
1 parent b7f2aa9 commit c33cd8a

File tree

6 files changed: +425, -115 lines changed


common/common.cpp

Lines changed: 133 additions & 0 deletions
@@ -317,6 +317,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -669,6 +675,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -824,6 +832,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
@@ -1062,6 +1194,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
    fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);

common/common.h

Lines changed: 35 additions & 0 deletions
@@ -63,6 +63,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance
 
     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative sampling
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
@@ -156,6 +157,40 @@ std::string llama_detokenize_bpe(
         llama_context * ctx,
         const std::vector<llama_token> & tokens);
 
+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+//  - ctx:    context to use for sampling
+//  - params: sampling parameters
+//
+// optional:
+//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+//  - grammar:      grammar to use for sampling, ignore if NULL
+//  - last_tokens:  needed for repetition penalty, ignore if empty
+//  - idx:          sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+//  - token:      sampled token
+//  - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx = 0);
+
+//
+// YAML utils
+//
+
 bool create_directory_with_parents(const std::string & path);
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
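The new model_draft field is meant to be turned into a second context by whatever example consumes it. Below is a minimal sketch of that wiring, assuming the usual common/ bootstrap of this era; the main() shown here, the _tgt/_dft names and the placement of the speculative loop are assumptions, only gpt_params_parse(), llama_init_from_gpt_params() and the llama_* calls are real.

// sketch only: load the target model from --model and the draft model from --model-draft
#include "common.h"
#include "llama.h"

#include <tuple>

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    llama_backend_init(params.numa);

    // target (large) model
    llama_model   * model_tgt = NULL;
    llama_context * ctx_tgt   = NULL;
    std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

    // draft (small) model: same parameters, but pointed at --model-draft
    gpt_params params_dft = params;
    params_dft.model = params.model_draft;

    llama_model   * model_dft = NULL;
    llama_context * ctx_dft   = NULL;
    std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params_dft);

    // ... drafting / verification loop would run here ...

    llama_free(ctx_dft);
    llama_free_model(model_dft);
    llama_free(ctx_tgt);
    llama_free_model(model_tgt);
    llama_backend_free();

    return 0;
}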

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(speculative)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
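The idx parameter added to llama_sample_token() is what makes draft verification practical: after the target model scores several drafted tokens in a single eval, idx selects which row of logits to sample from. The following is a rough sketch of that verification step, not the examples/speculative code: drafted, n_past and count_accepted() are hypothetical, and the target context is assumed to have been created with logits_all enabled so that one row of logits is kept per evaluated token (otherwise only the last row would be valid).

// sketch only: returns how many of the drafted tokens the target model agrees with
#include "common.h"
#include "llama.h"

#include <vector>

static int count_accepted(llama_context * ctx_tgt, const std::vector<llama_token> & drafted,
                          const gpt_params & params, int n_past) {
    std::vector<llama_token>      last_tokens;   // left empty: no repetition penalty in this sketch
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx_tgt));

    // score all drafted tokens with the target model in a single pass
    if (llama_eval(ctx_tgt, drafted.data(), (int) drafted.size(), n_past, params.n_threads) != 0) {
        return 0;
    }

    int n_accept = 0;

    // logits row i is the target's prediction for the token that should follow drafted[i];
    // for a deterministic check, set params.temp <= 0 so the helper samples greedily
    for (size_t i = 0; i + 1 < drafted.size(); ++i) {
        const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, (int) i);
        if (id != drafted[i + 1]) {
            break; // first disagreement: the remaining drafted tokens are discarded
        }
        n_accept++;
    }

    return n_accept;
}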
