
Commit e1ece90

cont : simplify logit_bias + add ignore_eos flag
ggml-ci
1 parent 052ec33 commit e1ece90

File tree: 7 files changed, +38 -27 lines changed

common/common.cpp

Lines changed: 9 additions & 13 deletions
```diff
@@ -1039,7 +1039,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ignore-eos") {
-        params.ignore_eos = true;
+        sparams.ignore_eos = true;
         return true;
     }
     if (arg == "--penalize-nl") {
@@ -1054,7 +1054,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         std::string value_str;
         try {
             if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
-                sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                sparams.logit_bias.push_back({key, bias});
             }
             else {
                 throw std::exception();
@@ -2165,8 +2166,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
-    if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+    if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+        fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sparams.ignore_eos = false;
     }
 
     if (params.warmup) {
@@ -3205,10 +3207,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
-
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
-    fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+    fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
 
     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
     fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3219,11 +3218,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
 
     fprintf(stream, "logit_bias:\n");
-    for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
-            continue;
-        }
-        fprintf(stream, " %d: %f", lb.first, lb.second);
+    for (const auto & logit_bias : sparams.logit_bias) {
+        fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
     }
 
     fprintf(stream, "lora:\n");
```

common/common.h

Lines changed: 0 additions & 1 deletion
```diff
@@ -172,7 +172,6 @@ struct gpt_params {
     bool flash_attn = false; // flash attention
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
-    bool ignore_eos = false; // ignore generated EOS tokens
     bool logits_all = false; // return logits for all tokens in the batch
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory
```

common/sampling.cpp

Lines changed: 6 additions & 2 deletions
```diff
@@ -332,8 +332,12 @@ static llama_token_data_array llama_sampling_prepare_impl(
     }
 
     // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
+    for (const auto & logit_bias : params.logit_bias) {
+        logits[logit_bias.token] += logit_bias.bias;
+    }
+
+    if (params.ignore_eos) {
+        logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
     }
 
     if (ctx_cfg) {
```
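Taken in isolation, the new preparation step does two things: add each bias to its token's logit, then hard-mask the EOS logit when `ignore_eos` is set. A self-contained sketch on a plain logits buffer; the local struct and the `eos_token` parameter stand in for `llama_logit_bias` and `llama_token_eos(...)`.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

struct logit_bias_entry { int32_t token; float bias; };

// additive biases first, then an optional hard mask on the EOS logit
static void apply_biases(std::vector<float> & logits,
                         const std::vector<logit_bias_entry> & logit_bias,
                         bool ignore_eos, int32_t eos_token) {
    for (const auto & lb : logit_bias) {
        logits[lb.token] += lb.bias;
    }
    if (ignore_eos && eos_token >= 0) {
        logits[eos_token] = -INFINITY; // EOS can never win the sampling step
    }
}

int main() {
    std::vector<float> logits(8, 0.0f);
    apply_biases(logits, {{3, 1.0f}, {5, -2.0f}}, /*ignore_eos =*/ true, /*eos_token =*/ 2);
    return 0;
}
```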

common/sampling.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -3,7 +3,6 @@
 #include "llama.h"
 
 #include <string>
-#include <unordered_map>
 #include <vector>
 
 // sampler types
@@ -37,6 +36,7 @@ typedef struct gpt_sampling_params {
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool penalize_nl = false; // consider newlines as a repeatable token
+    bool ignore_eos = false;
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
 
     std::vector<llama_sampler_type> samplers_sequence = {
@@ -55,7 +55,7 @@ typedef struct gpt_sampling_params {
     std::string cfg_negative_prompt; // string to help guidance
     float cfg_scale = 1.f; // how strong is guidance
 
-    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
```
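One behavioral side effect of the container change, not called out in the commit message: with `std::unordered_map`, repeated assignments to the same token kept only the last value, whereas with the vector all entries are kept and, since they are applied with `+=`, their biases accumulate. A tiny sketch:

```cpp
#include <cstdint>
#include <vector>

struct entry { int32_t token; float bias; };

int main() {
    std::vector<entry> logit_bias;
    logit_bias.push_back({42, 1.0f});
    logit_bias.push_back({42, 0.5f}); // duplicate token: both entries are kept ...

    float logit = 0.0f;
    for (const auto & lb : logit_bias) {
        logit += lb.bias;             // ... and both are added: net +1.5 rather than 0.5
    }
    return 0;
}
```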

examples/server/server.cpp

Lines changed: 6 additions & 9 deletions
```diff
@@ -1034,7 +1034,7 @@ struct server_context {
         slot.sparams.logit_bias.clear();
 
         if (json_value(data, "ignore_eos", false) && has_eos_token) {
-            slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+            slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
         }
 
         const auto & logit_bias = data.find("logit_bias");
@@ -1055,12 +1055,12 @@ struct server_context {
                     if (el[0].is_number_integer()) {
                         llama_token tok = el[0].get<llama_token>();
                         if (tok >= 0 && tok < n_vocab) {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                         }
                     } else if (el[0].is_string()) {
                         auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                         for (auto tok : toks) {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                         }
                     }
                 }
@@ -1313,9 +1313,6 @@ struct server_context {
     }
 
     json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
-        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
         std::vector<std::string> samplers_sequence;
         samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
         for (const auto & sampler_type : slot.sparams.samplers_sequence) {
@@ -1349,13 +1346,13 @@ struct server_context {
             {"max_tokens", slot.params.n_predict}, // User configured n_predict
             {"n_keep", slot.params.n_keep},
            {"n_discard", slot.params.n_discard},
-            {"ignore_eos", ignore_eos},
+            {"ignore_eos", slot.sparams.ignore_eos},
             {"stream", slot.params.stream},
-            {"logit_bias", slot.sparams.logit_bias},
+          //{"logit_bias", slot.sparams.logit_bias},
             {"n_probs", slot.sparams.n_probs},
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
-            {"samplers", samplers_sequence}
+            {"samplers", samplers_sequence},
         };
     }
```
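To make the request-side picture concrete, a hedged sketch of a body that exercises both parsing paths above, built with nlohmann::json as server.cpp does; the token id and the string are illustrative, and the field shapes are inferred from the hunks shown here.

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // "ignore_eos": true   -> push_back({llama_token_eos(model), -INFINITY}) if the model has EOS
    // "logit_bias" entries -> one push_back per token id, or per token of a tokenized string
    json data = {
        {"ignore_eos", true},
        {"logit_bias", json::array({
            json::array({2, -100.0}),    // [token id, bias]
            json::array({"Hello", 1.5}), // [string, bias] -> applied to every token of "Hello"
        })},
    };

    (void) data;
    return 0;
}
```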

include/llama.h

Lines changed: 11 additions & 0 deletions
```diff
@@ -356,6 +356,11 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;
 
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;
+
     // parameters for sampling the logits
     typedef struct llama_sampling_params {
         uint32_t seed; // the seed used to initialize llama_sampling_context
@@ -378,6 +383,12 @@ extern "C" {
         float mirostat_tau; // target entropy
         float mirostat_eta; // learning rate
         bool penalize_nl; // consider newlines as a repeatable token
+        bool ignore_eos; // ignore the end-of-sequence token
+
+        const char * grammar;
+
+        int32_t n_logit_bias;
+        const llama_logit_bias * logit_bias;
     } llama_sampling_params;
 
     // performance timing information
```
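A sketch of how calling code might fill the extended C struct; `llama_sampling_default_params()` is the initializer shown in the src/llama.cpp hunk below, and the token ids and biases are illustrative.

```cpp
#include "llama.h"

int main() {
    // illustrative entries; the array must outlive the params struct,
    // which stores only the pointer and the count
    static const llama_logit_bias biases[] = {
        { /*token =*/     2, /*bias =*/ -100.0f },
        { /*token =*/ 29871, /*bias =*/    0.5f },
    };

    llama_sampling_params sparams = llama_sampling_default_params();
    sparams.ignore_eos   = true;
    sparams.grammar      = nullptr; // no grammar constraint
    sparams.n_logit_bias = 2;
    sparams.logit_bias   = biases;

    (void) sparams;
    return 0;
}
```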

src/llama.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -16902,6 +16902,10 @@ struct llama_sampling_params llama_sampling_default_params() {
         /*.mirostat_tau =*/ 5.00f,
         /*.mirostat_eta =*/ 0.10f,
         /*.penalize_nl =*/ false,
+        /*.ignore_eos  =*/ false,
+        /*.grammar     =*/ nullptr,
+        /*.n_logit_bias =*/ 0,
+        /*.logit_bias   =*/ nullptr,
     };
 
     return result;
```
