Commit 61f822f

Added --logit-bias and --no-penalize-nl, removed std::span
1 parent f01c67f commit 61f822f

File tree

6 files changed: +185 -165 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ endif
 
 # keep standard at C11 and C++11
 CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC
+CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
 LDFLAGS =
 
 # warnings

examples/common.cpp

Lines changed: 42 additions & 15 deletions
@@ -6,6 +6,8 @@
 #include <string>
 #include <iterator>
 #include <algorithm>
+#include <sstream>
+#include <iostream>
 
 #if defined (_WIN32)
 #include <fcntl.h>
@@ -138,18 +140,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.repeat_penalty = std::stof(argv[i]);
-        } else if (arg == "--alpha_frequency") {
+        } else if (arg == "--frequency_penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params.alpha_frequency = std::stof(argv[i]);
-        } else if (arg == "--alpha_presence") {
+            params.frequency_penalty = std::stof(argv[i]);
+        } else if (arg == "--presence_penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params.alpha_presence = std::stof(argv[i]);
+            params.presence_penalty = std::stof(argv[i]);
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -227,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
-            params.ignore_eos = true;
+            params.logit_bias[llama_token_eos()] = -INFINITY;
+        } else if (arg == "--no-penalize-nl") {
+            params.penalize_nl = false;
+        } else if (arg == "-l" || arg == "--logit-bias") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::stringstream ss(argv[i]);
+            llama_token key;
+            char sign;
+            std::string value_str;
+            try {
+                if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-' || sign == '=' || sign == ':')) {
+                    params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                } else {
+                    throw std::exception();
+                }
+            } catch (const std::exception &e) {
+                invalid_param = true;
+                break;
+            }
         } else if (arg == "--n_parts") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -282,19 +305,23 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, " -f FNAME, --file FNAME\n");
     fprintf(stderr, " prompt file to start generation.\n");
     fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
-    fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k);
-    fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p);
-    fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z);
-    fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p);
-    fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n);
-    fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty);
-    fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence);
-    fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency);
-    fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat);
+    fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+    fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+    fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+    fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+    fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+    fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, 0 = disabled, 1 = mirostat, 2 = mirostat 2.0)\n", params.mirostat);
     fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
     fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stderr, " -l TOKEN+BIAS, --logit-bias TOKEN+BIAS");
+    fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
+    fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello'\n");
     fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
+    fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2+-inf)\n");
+    fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
     fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");

examples/common.h

Lines changed: 14 additions & 12 deletions
@@ -8,6 +8,7 @@
 #include <vector>
 #include <random>
 #include <thread>
+#include <unordered_map>
 
 //
 // CLI argument parsing
@@ -23,18 +24,19 @@ struct gpt_params {
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
 
     // sampling parameters
-    int32_t top_k = 0; // <= 0 to use vocab size
-    float top_p = 1.0f; // 1.0 = disabled
-    float tfs_z = 1.0f; // 1.0 = disabled
-    float typical_p = 1.0f; // 1.0 = disabled
-    float temp = 1.0f; // 1.0 = disabled
+    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+    int32_t top_k = 0; // <= 0 to use vocab size
+    float top_p = 1.0f; // 1.0 = disabled
+    float tfs_z = 1.0f; // 1.0 = disabled
+    float typical_p = 1.0f; // 1.0 = disabled
+    float temp = 1.0f; // 1.0 = disabled
     float repeat_penalty = 1.0f; // 1.0 = disabled
-    int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float alpha_frequency = 0.0f; // 0.0 = disabled
-    float alpha_presence = 0.0f; // 0.0 = disabled
-    int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float mirostat_tau = 5.0f; // target entropy
-    float mirostat_eta = 0.1f; // learning rate
+    int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float frequency_penalty = 0.0f; // 0.0 = disabled
+    float presence_penalty = 0.0f; // 0.0 = disabled
+    int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float mirostat_tau = 5.0f; // target entropy
+    float mirostat_eta = 0.1f; // learning rate
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
@@ -54,7 +56,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
 
     bool instruct = false; // instruction mode (used for Alpaca models)
-    bool ignore_eos = false; // do not stop generating after eos
+    bool penalize_nl = true; // consider newlines as a repeatable token
     bool perplexity = false; // compute perplexity over the prompt
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory

examples/main/main.cpp

Lines changed: 13 additions & 8 deletions
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
-        params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
+        params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
@@ -394,11 +394,12 @@ int main(int argc, char ** argv) {
             const float typical_p = params.typical_p;
             const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
             const float repeat_penalty = params.repeat_penalty;
-            const float alpha_presence = params.alpha_presence;
-            const float alpha_frequency = params.alpha_frequency;
-            const int mirostat = params.mirostat;
+            const float alpha_presence = params.presence_penalty;
+            const float alpha_frequency = params.frequency_penalty;
+            const int mirostat = params.mirostat;
             const float mirostat_tau = params.mirostat_tau;
             const float mirostat_eta = params.mirostat_eta;
+            const bool penalize_nl = params.penalize_nl;
 
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session) {
@@ -412,8 +413,9 @@ int main(int argc, char ** argv) {
                 auto logits = llama_get_logits(ctx);
                 auto n_vocab = llama_n_vocab(ctx);
 
-                if (params.ignore_eos) {
-                    logits[llama_token_eos()] = -INFINITY;
+                // Apply params.logit_bias map
+                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+                    logits[it->first] += it->second;
                 }
 
                 std::vector<llama_token_data> candidates;
@@ -425,14 +427,17 @@ int main(int argc, char ** argv) {
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
                 // Apply penalties
+                float nl_logit = logits[llama_token_nl()];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
                 llama_sample_repetition_penalty(ctx, &candidates_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, repeat_penalty);
                 llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, alpha_frequency, alpha_presence);
-
+                if (!penalize_nl) {
+                    logits[llama_token_nl()] = nl_logit;
+                }
 
                 if (temp <= 0) {
                     // Greedy sampling
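
Tying the pieces together, a minimal standalone sketch of the two mechanisms the new main.cpp relies on: every entry of params.logit_bias is added to the matching logit before sampling, so a bias of -INFINITY makes a token unsamplable (which is how --ignore-eos is now expressed), and --no-penalize-nl saves the newline logit before the penalty step and restores it afterwards. For brevity the penalty stand-in here acts directly on the logits array, whereas the commit penalizes a separate candidates array via llama_sample_repetition_penalty; the token ids used are assumptions noted in the comments.

// Standalone sketch of the logit-bias and newline-restore logic; no llama_context involved.
// Token ids are assumptions: 2 for EOS (as the --ignore-eos usage text implies), 13 for '\n'.
#include <cmath>
#include <cstdio>
#include <unordered_map>
#include <vector>

typedef int llama_token;

int main() {
    const llama_token token_eos = 2;
    const llama_token token_nl  = 13;

    std::vector<float> logits(32000, 0.0f);  // one logit per vocab entry
    logits[token_nl] = 1.5f;

    // --ignore-eos is now just a preset bias entry
    std::unordered_map<llama_token, float> logit_bias;
    logit_bias[token_eos] = -INFINITY;

    // apply the bias map to the logits, as the new sampling loop does
    for (const auto & kv : logit_bias) {
        logits[kv.first] += kv.second;
    }

    // --no-penalize-nl: remember the newline logit, run the penalties, then put it back
    const bool penalize_nl = false;
    const float nl_logit = logits[token_nl];
    logits[token_nl] /= 1.3f;                // stand-in for the repetition penalty
    if (!penalize_nl) {
        logits[token_nl] = nl_logit;
    }

    printf("eos logit = %f, nl logit = %f\n", logits[token_eos], logits[token_nl]);
    return 0;
}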
