Skip to content

Commit f01c67f

Browse files
mirostat
1 parent 9b3b07c commit f01c67f

File tree

7 files changed

+427
-93
lines changed

7 files changed

+427
-93
lines changed

examples/common.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
150150
break;
151151
}
152152
params.alpha_presence = std::stof(argv[i]);
153+
} else if (arg == "--mirostat") {
154+
if (++i >= argc) {
155+
invalid_param = true;
156+
break;
157+
}
158+
params.mirostat = std::stoi(argv[i]);
159+
} else if (arg == "--mirostat_eta") {
160+
if (++i >= argc) {
161+
invalid_param = true;
162+
break;
163+
}
164+
params.mirostat_eta = std::stof(argv[i]);
165+
} else if (arg == "--mirostat_tau") {
166+
if (++i >= argc) {
167+
invalid_param = true;
168+
break;
169+
}
170+
params.mirostat_tau = std::stof(argv[i]);
153171
} else if (arg == "-b" || arg == "--batch_size") {
154172
if (++i >= argc) {
155173
invalid_param = true;
@@ -264,14 +282,17 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
264282
fprintf(stderr, " -f FNAME, --file FNAME\n");
265283
fprintf(stderr, " prompt file to start generation.\n");
266284
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
267-
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
268-
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
269-
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
270-
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
271-
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
272-
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
273-
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
274-
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
285+
fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k);
286+
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p);
287+
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z);
288+
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p);
289+
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n);
290+
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty);
291+
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence);
292+
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency);
293+
fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat);
294+
fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
295+
fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
275296
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
276297
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
277298
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");

examples/common.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,24 @@ struct gpt_params {
1717
int32_t seed = -1; // RNG seed
1818
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
1919
int32_t n_predict = 128; // new tokens to predict
20-
int32_t repeat_last_n = 64; // last n tokens to penalize
2120
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
2221
int32_t n_ctx = 512; // context size
2322
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
2423
int32_t n_keep = 0; // number of tokens to keep from initial prompt
2524

2625
// sampling parameters
27-
int32_t top_k = 40;
28-
float top_p = 0.95f;
29-
float temp = 0.80f;
30-
float repeat_penalty = 1.10f;
26+
int32_t top_k = 0; // <= 0 to use vocab size
27+
float top_p = 1.0f; // 1.0 = disabled
28+
float tfs_z = 1.0f; // 1.0 = disabled
29+
float typical_p = 1.0f; // 1.0 = disabled
30+
float temp = 1.0f; // 1.0 = disabled
31+
float repeat_penalty = 1.0f; // 1.0 = disabled
32+
int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
33+
float alpha_frequency = 0.0f; // 0.0 = disabled
34+
float alpha_presence = 0.0f; // 0.0 = disabled
35+
int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
36+
float mirostat_tau = 5.0f; // target entropy
37+
float mirostat_eta = 0.1f; // learning rate
3138

3239
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
3340
std::string prompt = "";

examples/main/main.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
276276
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
277277
}
278278
}
279-
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
280-
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
279+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
280+
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
281281
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
282282
fprintf(stderr, "\n\n");
283283

@@ -396,6 +396,9 @@ int main(int argc, char ** argv) {
396396
const float repeat_penalty = params.repeat_penalty;
397397
const float alpha_presence = params.alpha_presence;
398398
const float alpha_frequency = params.alpha_frequency;
399+
const int mirostat = params.mirostat;
400+
const float mirostat_tau = params.mirostat_tau;
401+
const float mirostat_eta = params.mirostat_eta;
399402

400403
// optionally save the session on first sample (for faster prompt loading next time)
401404
if (!path_session.empty() && need_to_save_session) {
@@ -415,47 +418,45 @@ int main(int argc, char ** argv) {
415418

416419
std::vector<llama_token_data> candidates;
417420
candidates.reserve(n_vocab);
418-
for (size_t i = 0; i < n_vocab; i++) {
421+
for (size_t i = 0; i < (size_t) n_vocab; i++) {
419422
candidates.emplace_back(i, logits[i], 0.0f);
420423
}
421424

422-
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
425+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
423426

424427
// Apply penalties
425428
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
426-
llama_sample_repetition_penalty(&candidates_p,
429+
llama_sample_repetition_penalty(ctx, &candidates_p,
427430
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
428431
last_n_repeat, repeat_penalty);
429-
llama_sample_frequency_and_presence_penalties(&candidates_p,
432+
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
430433
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
431434
last_n_repeat, alpha_frequency, alpha_presence);
432435

433436

434-
#if 1
435437
if (temp <= 0) {
436438
// Greedy sampling
437439
id = llama_sample_token_greedy(ctx, &candidates_p);
438440
} else {
439-
// Temperature sampling
440-
llama_sample_top_k(&candidates_p, top_k);
441-
llama_sample_tail_free(&candidates_p, tfs_z);
442-
llama_sample_typical(&candidates_p, typical_p);
443-
llama_sample_top_p(&candidates_p, top_p);
444-
445-
llama_sample_temperature(&candidates_p, temp);
446-
// printf("`%d`", candidates_p.size);
447-
id = llama_sample_token(ctx, &candidates_p);
441+
if (mirostat == 1) {
442+
static float mirostat_mu = 2.0f * mirostat_tau;
443+
static int mirostat_k = 40;
444+
const int mirostat_m = 100;
445+
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu);
446+
} else if (mirostat == 2) {
447+
static float mirostat_mu = 2.0f * mirostat_tau;
448+
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
449+
} else {
450+
// Temperature sampling
451+
llama_sample_top_k(ctx, &candidates_p, top_k);
452+
llama_sample_tail_free(ctx, &candidates_p, tfs_z);
453+
llama_sample_typical(ctx, &candidates_p, typical_p);
454+
llama_sample_top_p(ctx, &candidates_p, top_p);
455+
llama_sample_temperature(ctx, &candidates_p, temp);
456+
id = llama_sample_token(ctx, &candidates_p);
457+
}
448458
}
449-
#else
450-
const float tau = 5.0f;
451-
static float mu = 2.0f * tau;
452-
static int k = 40;
453-
const float eta = 0.1f;
454-
const int m = 100;
455-
const float N = n_vocab;
456-
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
457-
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
458-
#endif
459+
// printf("`%d`", candidates_p.size);
459460

460461
last_n_tokens.erase(last_n_tokens.begin());
461462
last_n_tokens.push_back(id);

0 commit comments

Comments
 (0)