
Commit 6f9722c

server : allow to specify custom prompt for penalty calculation
1 parent 8a5be3b commit 6f9722c

4 files changed: 54 additions, 3 deletions

common/sampling.cpp

Lines changed: 5 additions & 3 deletions
```diff
@@ -193,12 +193,14 @@ llama_token llama_sampling_sample(
     }

     // apply penalties
-    if (!prev.empty()) {
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

         llama_sample_repetition_penalties(ctx_main, &cur_p,
-                prev.data() + prev.size() - penalty_last_n,
-                penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
```
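
The hunk above changes which token history feeds the repetition/frequency/presence penalties: when `use_penalty_prompt_tokens` is set, the last `penalty_last_n` entries of `penalty_prompt_tokens` are penalized instead of the last generated tokens, clamped to however many tokens are actually available. A minimal standalone sketch of just that windowing logic (hypothetical token ids, not llama.cpp code):

```cpp
// Sketch of the penalty-window selection: pick the list, then take its last
// `penalty_last_n` tokens (or fewer, if the list is shorter).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t; // assumption: token ids are 32-bit integers

int main() {
    const std::vector<llama_token> prev                  = {5, 6, 7, 8}; // recently generated tokens
    const std::vector<llama_token> penalty_prompt_tokens = {1, 2, 3};    // custom penalty prompt
    const bool use_penalty_prompt_tokens = true;
    const int  penalty_last_n            = 2;

    const auto &penalty_tokens = use_penalty_prompt_tokens ? penalty_prompt_tokens : prev;
    const int   used           = std::min((int) penalty_tokens.size(), penalty_last_n);

    // this window is what llama_sample_repetition_penalties() would operate on
    for (int i = (int) penalty_tokens.size() - used; i < (int) penalty_tokens.size(); ++i) {
        printf("penalized token: %d\n", (int) penalty_tokens[i]);
    }
    return 0;
}
```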

common/sampling.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -36,6 +36,9 @@ typedef struct llama_sampling_params {
     float       cfg_scale     = 1.f; // how strong is guidance

     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+    std::vector<llama_token> penalty_prompt_tokens;
+    bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;

 // general sampler context
```

examples/server/README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -148,6 +148,8 @@ node index.js

     `frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled);

+    `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens (default: `null` = use the original `prompt`).
+
     `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).

     `mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0).
```
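
For illustration (the surrounding fields are existing server parameters; the values here are made up), a completion request can steer the repetition penalty with a fixed reference text instead of the actual prompt:

```json
{
  "prompt": "Write a short story about a llama.",
  "n_predict": 128,
  "repeat_penalty": 1.2,
  "penalty_prompt": "llama llama llama"
}
```

The array form supplies token ids directly, e.g. `"penalty_prompt": [5431, 7853]` (hypothetical ids); as the server code below shows, non-integer or out-of-vocabulary entries are silently skipped.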

examples/server/server.cpp

Lines changed: 44 additions & 0 deletions
```diff
@@ -760,6 +760,42 @@ struct llama_server_context
                 slot->prompt = "";
             }

+            slot->sparams.penalty_prompt_tokens.clear();
+            slot->sparams.use_penalty_prompt_tokens = false;
+            const auto &penalty_prompt = data.find("penalty_prompt");
+            if (penalty_prompt != data.end())
+            {
+                if (penalty_prompt->is_string())
+                {
+                    const auto penalty_prompt_string = penalty_prompt->get<std::string>();
+                    auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
+                    slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
+                    if (slot->params.n_predict > 0)
+                    {
+                        slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
+                    }
+                    slot->sparams.use_penalty_prompt_tokens = true;
+                }
+                else if (penalty_prompt->is_array())
+                {
+                    const auto n_tokens = penalty_prompt->size();
+                    slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
+                    const int n_vocab = llama_n_vocab(model);
+                    for (const auto &penalty_token : *penalty_prompt)
+                    {
+                        if (penalty_token.is_number_integer())
+                        {
+                            const auto tok = penalty_token.get<llama_token>();
+                            if (tok >= 0 && tok < n_vocab)
+                            {
+                                slot->sparams.penalty_prompt_tokens.push_back(tok);
+                            }
+                        }
+                    }
+                    slot->sparams.use_penalty_prompt_tokens = true;
+                }
+            }
+
             slot->sparams.logit_bias.clear();

             if (json_value(data, "ignore_eos", false))
```
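
A condensed sketch of the same string-or-array handling, using nlohmann::json (the library behind the server's `data` object, judging by the calls above); the tokenizer is stubbed out, so this illustrates the accepted shapes rather than reproducing the server's code:

```cpp
// Accept either a string or an array of integer token ids for a "penalty_prompt" field.
#include <cstdio>
#include <string>
#include <vector>

#include <nlohmann/json.hpp> // single-header JSON library with the same API as the server's json.hpp

using json = nlohmann::json;

static std::vector<int> parse_penalty_prompt(const json &data, int n_vocab) {
    std::vector<int> tokens;
    const auto it = data.find("penalty_prompt");
    if (it == data.end()) {
        return tokens; // field absent -> caller falls back to the generated-token history
    }
    if (it->is_string()) {
        // stand-in for llama_tokenize(model, ...): one fake id per character
        for (char c : it->get<std::string>()) {
            tokens.push_back((unsigned char) c % n_vocab);
        }
    } else if (it->is_array()) {
        for (const auto &t : *it) {
            if (t.is_number_integer()) {
                const int tok = t.get<int>();
                if (tok >= 0 && tok < n_vocab) { // ignore out-of-range ids, as the server does
                    tokens.push_back(tok);
                }
            }
        }
    }
    return tokens;
}

int main() {
    const int  n_vocab   = 32000;
    const json as_string = json::parse(R"({"penalty_prompt": "llama llama"})");
    const json as_array  = json::parse(R"({"penalty_prompt": [15, 42, -1, 99999999]})");

    printf("string form -> %zu tokens\n", parse_penalty_prompt(as_string, n_vocab).size());
    printf("array form  -> %zu tokens\n", parse_penalty_prompt(as_array,  n_vocab).size());
    return 0;
}
```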
```diff
@@ -991,6 +1027,12 @@ struct llama_server_context
         slot.generated_text += token_str;
         slot.has_next_token = true;

+        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
+        {
+            // we can change penalty_prompt_tokens because it is always created from scratch each request
+            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
+        }
+
         // check if there is incomplete UTF-8 character at the end
         bool incomplete = false;
         for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
```
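
Since every accepted token is appended to `penalty_prompt_tokens` here, the custom penalty prompt acts as a prefix that grows with the generation: the penalty window ends up covering both the user-supplied tokens and whatever has been generated so far in the request. This is also why the parsing code above reserves `n_predict` extra slots, and why mutating the vector is safe: as the in-code comment notes, it is rebuilt from scratch for each request.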
```diff
@@ -1182,6 +1224,8 @@ struct llama_server_context
             {"repeat_penalty",            slot.sparams.penalty_repeat},
             {"presence_penalty",          slot.sparams.penalty_present},
             {"frequency_penalty",         slot.sparams.penalty_freq},
+            {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
+            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
             {"mirostat",                  slot.sparams.mirostat},
             {"mirostat_tau",              slot.sparams.mirostat_tau},
             {"mirostat_eta",              slot.sparams.mirostat_eta},
```
