Commit f8ebe38

server : add token healing support
1 parent 73fc9d8 commit f8ebe38

2 files changed: +72 -7 lines changed

examples/server/README.md

Lines changed: 2 additions & 0 deletions
@@ -266,6 +266,8 @@ node index.js
 
 `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
 
+`token_healing`: Set token healing strategy. Default: `0`, which is disabled.
+
 `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
 
 `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
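
As a usage illustration (not part of the commit): a completion request can pass the new field next to the other sampling parameters. Going by the parsing code in `server.cpp` below, the accepted string values are `0` (disabled), `1` (roll back the last token), `d1` (dynamic, single step), `d` (dynamic, multiple steps) and `r<N>` (roll back `N` tokens); the prompt shown here is made up.

```json
{
  "prompt": "The quick brown fo",
  "n_predict": 32,
  "token_healing": "d1"
}
```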

examples/server/server.cpp

Lines changed: 70 additions & 7 deletions
@@ -184,6 +184,7 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix = 0; // size of remaining token healing prefix
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;

@@ -205,6 +206,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        n_th_prefix = 0;
 
         generated_token_probs.clear();
     }

@@ -1083,6 +1085,36 @@ struct server_context {
             }
         }
 
+        {
+            const auto & token_healing_str = data.find("token_healing");
+            auto & th_enabled = slot.sparams.token_healing_enabled;
+            th_enabled = default_sparams.token_healing_enabled;
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                auto & th_type = slot.sparams.token_healing_type;
+                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
+                th_enabled = true;
+                /**/ if (value == "0" ) { th_enabled = false; }
+                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+                else if (value[0] == 'r' ) {
+                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
+                    th_n_rollback = std::stoi(value.substr(1));
+                    if (th_n_rollback <= 0) {
+                        th_enabled = false;
+                    }
+                } else { th_enabled = false; }
+
+                LOG_VERBOSE("token healing", {
+                    {"id_slot", slot.id},
+                    {"enabled", th_enabled},
+                    {"type", th_type},
+                    {"n_rollback", th_n_rollback}
+                });
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);

@@ -1178,14 +1210,26 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
         if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
             // we can change penalty_prompt_tokens because it is always created from scratch each request
             slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
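
A minimal standalone sketch of the suppression rule added above, assuming the prompt's trailing piece ` fo` (3 bytes) was rolled back as the healing prefix; the helper name and the toy `main` are hypothetical and not part of the patch:

```cpp
#include <cstddef>
#include <iostream>
#include <string>

// Mirrors the logic in process_token(): while a healing prefix remains, sampled token
// text is swallowed until the prefix has been re-generated; a token that overruns the
// prefix contributes only its tail. Returns true if the token is fully suppressed.
static bool append_after_healing(std::string & generated, size_t & n_th_prefix,
                                 const std::string & token_str) {
    if (n_th_prefix == 0) {
        generated += token_str;                     // healing finished: normal append
        return false;
    }
    if (n_th_prefix < token_str.size()) {
        generated += token_str.substr(n_th_prefix); // keep only the part past the prefix
        n_th_prefix = 0;
        return false;                               // partial token text can still be streamed
    }
    n_th_prefix -= token_str.size();                // token fully inside the prefix: suppress it
    return true;
}

int main() {
    std::string generated;
    size_t n_th_prefix = 3;                                 // length of the rolled-back " fo"
    append_after_healing(generated, n_th_prefix, " fox");   // overruns the prefix -> appends "x"
    append_after_healing(generated, n_th_prefix, " jumps"); // appended verbatim
    std::cout << generated << '\n';                         // prints "x jumps"
}
```

In the server this runs once per sampled token, and the hunks that follow keep `has_next_token` set and skip the stop-word scan while a token is still being suppressed, so a multi-token healing prefix is consumed before any text is streamed to the client.
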
@@ -1213,7 +1257,7 @@ struct server_context {
                 break;
             }
 
-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);

@@ -1245,7 +1289,7 @@ struct server_context {
             }
         }
 
-        if (incomplete) {
+        if (incomplete || is_token_healing) {
            slot.has_next_token = true;
        }
 
@@ -1350,7 +1394,8 @@ struct server_context {
            {"n_probs", slot.sparams.n_probs},
            {"min_keep", slot.sparams.min_keep},
            {"grammar", slot.sparams.grammar},
-           {"samplers", samplers_sequence}
+           {"samplers", samplers_sequence},
+           {"token_healing_enabled", slot.sparams.token_healing_enabled}
        };
    }
 
@@ -2082,6 +2127,21 @@ struct server_context {
                    continue;
                }
 
+                // Roll back prompt tokens if token healing
+                llama_token_healing_output token_healing_out{};
+                if (slot.sparams.token_healing_enabled) {
+                    token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                     prompt_tokens, slot.sparams.token_healing_n_rollback);
+                    slot.n_th_prefix = token_healing_out.prefix.size();
+                    slot.n_prompt_tokens = prompt_tokens.size();
+                    LOG_VERBOSE("token healing prompt", {
+                        {"id_slot", slot.id},
+                        {"id_task", slot.id_task},
+                        {"removed_suffix", token_healing_out.prefix},
+                        {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                    });
+                }
+
                if (slot.embedding) {
                    // this prompt is too large to process - discard it
                    if (slot.n_prompt_tokens > n_ubatch) {

@@ -2132,6 +2192,9 @@ struct server_context {
                }
 
                llama_sampling_reset(slot.ctx_sampling);
+                if (slot.sparams.token_healing_enabled) {
+                    llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+                }
 
                if (!slot.params.cache_prompt) {
                    slot.n_past_se = 0;
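
For orientation, a toy sketch of what the rollback step above produces; real tokenization and the dynamic strategies live in the llama.cpp token-healing helpers (`llama_token_healing_rollback`, `llama_token_healing_set_prefix`), which are not re-implemented here, and the struct and function below are hypothetical:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for llama_token_healing_output: the text of the removed
// prompt suffix plus the number of prompt tokens that were dropped.
struct healing_output {
    std::string prefix;
    int         n_tokens_removed = 0;
};

// Toy single-token rollback on a pre-tokenized prompt.
static healing_output rollback_last(std::vector<std::string> & prompt_pieces) {
    healing_output out{};
    if (!prompt_pieces.empty()) {
        out.prefix           = prompt_pieces.back();
        out.n_tokens_removed = 1;
        prompt_pieces.pop_back(); // the model will re-generate this span itself
    }
    return out;
}

int main() {
    std::vector<std::string> prompt = {"The", " quick", " brown", " fo"};
    const healing_output out = rollback_last(prompt);
    // " fo" is dropped from the prompt; its byte length (3) seeds slot.n_th_prefix, and the
    // sampler is given the same text as a prefix so generation continues from " fo".
    std::cout << "removed: '" << out.prefix << "', n_tokens_removed: " << out.n_tokens_removed
              << ", prompt tokens left: " << prompt.size() << '\n';
}
```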
