Skip to content

Commit e62c491

Browse files
mirostat
1 parent 9c78250 commit e62c491

File tree

7 files changed

+427
-93
lines changed

7 files changed

+427
-93
lines changed

examples/common.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
144144
break;
145145
}
146146
params.alpha_presence = std::stof(argv[i]);
147+
} else if (arg == "--mirostat") {
148+
if (++i >= argc) {
149+
invalid_param = true;
150+
break;
151+
}
152+
params.mirostat = std::stoi(argv[i]);
153+
} else if (arg == "--mirostat_eta") {
154+
if (++i >= argc) {
155+
invalid_param = true;
156+
break;
157+
}
158+
params.mirostat_eta = std::stof(argv[i]);
159+
} else if (arg == "--mirostat_tau") {
160+
if (++i >= argc) {
161+
invalid_param = true;
162+
break;
163+
}
164+
params.mirostat_tau = std::stof(argv[i]);
147165
} else if (arg == "-b" || arg == "--batch_size") {
148166
if (++i >= argc) {
149167
invalid_param = true;
@@ -259,14 +277,17 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
259277
fprintf(stderr, " -f FNAME, --file FNAME\n");
260278
fprintf(stderr, " prompt file to start generation.\n");
261279
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
262-
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
263-
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
264-
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
265-
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
266-
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
267-
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
268-
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
269-
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
280+
fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k);
281+
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p);
282+
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z);
283+
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p);
284+
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n);
285+
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty);
286+
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence);
287+
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency);
288+
fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat);
289+
fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
290+
fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
270291
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
271292
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
272293
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");

examples/common.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,24 @@ struct gpt_params {
1717
int32_t seed = -1; // RNG seed
1818
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
1919
int32_t n_predict = 128; // new tokens to predict
20-
int32_t repeat_last_n = 64; // last n tokens to penalize
2120
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
2221
int32_t n_ctx = 512; // context size
2322
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
2423
int32_t n_keep = 0; // number of tokens to keep from initial prompt
2524

2625
// sampling parameters
27-
int32_t top_k = 40;
28-
float top_p = 0.95f;
29-
float temp = 0.80f;
30-
float repeat_penalty = 1.10f;
26+
int32_t top_k = 0; // <= 0 to use vocab size
27+
float top_p = 1.0f; // 1.0 = disabled
28+
float tfs_z = 1.0f; // 1.0 = disabled
29+
float typical_p = 1.0f; // 1.0 = disabled
30+
float temp = 1.0f; // 1.0 = disabled
31+
float repeat_penalty = 1.0f; // 1.0 = disabled
32+
int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
33+
float alpha_frequency = 0.0f; // 0.0 = disabled
34+
float alpha_presence = 0.0f; // 0.0 = disabled
35+
int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
36+
float mirostat_tau = 5.0f; // target entropy
37+
float mirostat_eta = 0.1f; // learning rate
3138

3239
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
3340
std::string prompt = "";

examples/main/main.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ int main(int argc, char ** argv) {
230230
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
231231
}
232232
}
233-
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
234-
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
233+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
234+
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
235235
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
236236
fprintf(stderr, "\n\n");
237237

@@ -313,6 +313,9 @@ int main(int argc, char ** argv) {
313313
const float repeat_penalty = params.repeat_penalty;
314314
const float alpha_presence = params.alpha_presence;
315315
const float alpha_frequency = params.alpha_frequency;
316+
const int mirostat = params.mirostat;
317+
const float mirostat_tau = params.mirostat_tau;
318+
const float mirostat_eta = params.mirostat_eta;
316319

317320
llama_token id = 0;
318321

@@ -326,47 +329,45 @@ int main(int argc, char ** argv) {
326329

327330
std::vector<llama_token_data> candidates;
328331
candidates.reserve(n_vocab);
329-
for (size_t i = 0; i < n_vocab; i++) {
332+
for (size_t i = 0; i < (size_t) n_vocab; i++) {
330333
candidates.emplace_back(i, logits[i], 0.0f);
331334
}
332335

333-
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
336+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
334337

335338
// Apply penalties
336339
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
337-
llama_sample_repetition_penalty(&candidates_p,
340+
llama_sample_repetition_penalty(ctx, &candidates_p,
338341
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
339342
last_n_repeat, repeat_penalty);
340-
llama_sample_frequency_and_presence_penalties(&candidates_p,
343+
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
341344
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
342345
last_n_repeat, alpha_frequency, alpha_presence);
343346

344347

345-
#if 1
346348
if (temp <= 0) {
347349
// Greedy sampling
348350
id = llama_sample_token_greedy(ctx, &candidates_p);
349351
} else {
350-
// Temperature sampling
351-
llama_sample_top_k(&candidates_p, top_k);
352-
llama_sample_tail_free(&candidates_p, tfs_z);
353-
llama_sample_typical(&candidates_p, typical_p);
354-
llama_sample_top_p(&candidates_p, top_p);
355-
356-
llama_sample_temperature(&candidates_p, temp);
357-
// printf("`%d`", candidates_p.size);
358-
id = llama_sample_token(ctx, &candidates_p);
352+
if (mirostat == 1) {
353+
static float mirostat_mu = 2.0f * mirostat_tau;
354+
static int mirostat_k = 40;
355+
const int mirostat_m = 100;
356+
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu);
357+
} else if (mirostat == 2) {
358+
static float mirostat_mu = 2.0f * mirostat_tau;
359+
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
360+
} else {
361+
// Temperature sampling
362+
llama_sample_top_k(ctx, &candidates_p, top_k);
363+
llama_sample_tail_free(ctx, &candidates_p, tfs_z);
364+
llama_sample_typical(ctx, &candidates_p, typical_p);
365+
llama_sample_top_p(ctx, &candidates_p, top_p);
366+
llama_sample_temperature(ctx, &candidates_p, temp);
367+
id = llama_sample_token(ctx, &candidates_p);
368+
}
359369
}
360-
#else
361-
const float tau = 5.0f;
362-
static float mu = 2.0f * tau;
363-
static int k = 40;
364-
const float eta = 0.1f;
365-
const int m = 100;
366-
const float N = n_vocab;
367-
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
368-
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
369-
#endif
370+
// printf("`%d`", candidates_p.size);
370371

371372
last_n_tokens.erase(last_n_tokens.begin());
372373
last_n_tokens.push_back(id);

0 commit comments

Comments
 (0)