Skip to content

Commit 9b3b07c

Browse files
Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat

Ignore EOS fix: -inf should be used.
1 parent 5fba3c0 commit 9b3b07c

File tree

7 files changed

+450
-131
lines changed

7 files changed

+450
-131
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
7676
# Compile flags
7777
#
7878

79-
set(CMAKE_CXX_STANDARD 11)
79+
set(CMAKE_CXX_STANDARD 20)
8080
set(CMAKE_CXX_STANDARD_REQUIRED true)
8181
set(CMAKE_C_STANDARD 11)
8282
set(CMAKE_C_STANDARD_REQUIRED true)

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ endif
3535

3636
# keep standard at C11 and C++20
3737
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
38-
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
38+
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC
3939
LDFLAGS =
4040

4141
# warnings

examples/common.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
114114
break;
115115
}
116116
params.temp = std::stof(argv[i]);
117+
} else if (arg == "--tfs") {
118+
if (++i >= argc) {
119+
invalid_param = true;
120+
break;
121+
}
122+
params.tfs_z = std::stof(argv[i]);
123+
} else if (arg == "--typical") {
124+
if (++i >= argc) {
125+
invalid_param = true;
126+
break;
127+
}
128+
params.typical_p = std::stof(argv[i]);
117129
} else if (arg == "--repeat_last_n") {
118130
if (++i >= argc) {
119131
invalid_param = true;
@@ -126,6 +138,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
126138
break;
127139
}
128140
params.repeat_penalty = std::stof(argv[i]);
141+
} else if (arg == "--alpha_frequency") {
142+
if (++i >= argc) {
143+
invalid_param = true;
144+
break;
145+
}
146+
params.alpha_frequency = std::stof(argv[i]);
147+
} else if (arg == "--alpha_presence") {
148+
if (++i >= argc) {
149+
invalid_param = true;
150+
break;
151+
}
152+
params.alpha_presence = std::stof(argv[i]);
129153
} else if (arg == "-b" || arg == "--batch_size") {
130154
if (++i >= argc) {
131155
invalid_param = true;
@@ -242,6 +266,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
242266
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
243267
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
244268
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
269+
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
270+
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
271+
// alpha_presence is a float (it is parsed with std::stof), so it must be printed with a
// floating-point conversion; passing it to %d is undefined behavior in printf-style calls.
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f)\n", (double)params.alpha_presence);
272+
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
245273
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
246274
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
247275
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);

examples/main/main.cpp

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
276276
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
277277
}
278278
}
279-
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
280-
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
279+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
280+
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
281281
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
282282
fprintf(stderr, "\n\n");
283283

@@ -387,10 +387,15 @@ int main(int argc, char ** argv) {
387387

388388
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
389389
// out of user input, sample next token
390-
const int32_t top_k = params.top_k;
391-
const float top_p = params.top_p;
392390
const float temp = params.temp;
391+
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
392+
const float top_p = params.top_p;
393+
const float tfs_z = params.tfs_z;
394+
const float typical_p = params.typical_p;
395+
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
393396
const float repeat_penalty = params.repeat_penalty;
397+
const float alpha_presence = params.alpha_presence;
398+
const float alpha_frequency = params.alpha_frequency;
394399

395400
// optionally save the session on first sample (for faster prompt loading next time)
396401
if (!path_session.empty() && need_to_save_session) {
@@ -402,14 +407,55 @@ int main(int argc, char ** argv) {
402407

403408
{
404409
auto logits = llama_get_logits(ctx);
410+
auto n_vocab = llama_n_vocab(ctx);
405411

406412
if (params.ignore_eos) {
407-
logits[llama_token_eos()] = 0;
413+
logits[llama_token_eos()] = -INFINITY;
414+
}
415+
416+
std::vector<llama_token_data> candidates;
417+
candidates.reserve(n_vocab);
418+
for (size_t i = 0; i < n_vocab; i++) {
419+
candidates.emplace_back(i, logits[i], 0.0f);
408420
}
409421

410-
id = llama_sample_top_p_top_k(ctx,
411-
last_n_tokens.data() + n_ctx - params.repeat_last_n,
412-
params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
422+
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
423+
424+
// Apply penalties
425+
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
426+
llama_sample_repetition_penalty(&candidates_p,
427+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
428+
last_n_repeat, repeat_penalty);
429+
llama_sample_frequency_and_presence_penalties(&candidates_p,
430+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
431+
last_n_repeat, alpha_frequency, alpha_presence);
432+
433+
434+
#if 1
435+
if (temp <= 0) {
436+
// Greedy sampling
437+
id = llama_sample_token_greedy(ctx, &candidates_p);
438+
} else {
439+
// Temperature sampling
440+
llama_sample_top_k(&candidates_p, top_k);
441+
llama_sample_tail_free(&candidates_p, tfs_z);
442+
llama_sample_typical(&candidates_p, typical_p);
443+
llama_sample_top_p(&candidates_p, top_p);
444+
445+
llama_sample_temperature(&candidates_p, temp);
446+
// printf("`%d`", candidates_p.size);
447+
id = llama_sample_token(ctx, &candidates_p);
448+
}
449+
#else
450+
const float tau = 5.0f;
451+
static float mu = 2.0f * tau;
452+
static int k = 40;
453+
const float eta = 0.1f;
454+
const int m = 100;
455+
const float N = n_vocab;
456+
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
457+
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
458+
#endif
413459

414460
last_n_tokens.erase(last_n_tokens.begin());
415461
last_n_tokens.push_back(id);

0 commit comments

Comments
 (0)