@@ -184,6 +184,7 @@ struct server_slot {
     // stats
     size_t n_sent_text        = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix        = 0; // size of remaining token healing prefix

     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -205,6 +206,7 @@ struct server_slot {
         infill             = false;
         ga_i               = 0;
         n_past_se          = 0;
+        n_th_prefix        = 0;

         generated_token_probs.clear();
     }
@@ -1083,6 +1085,36 @@ struct server_context {
             }
         }

+        {
+            const auto & token_healing_str = data.find("token_healing");
+            auto & th_enabled = slot.sparams.token_healing_enabled;
+            th_enabled = default_sparams.token_healing_enabled;
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                auto & th_type       = slot.sparams.token_healing_type;
+                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
+                th_enabled = true;
+                /* */ if (value == "0" ) { th_enabled = false; }
+                else  if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+                else  if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+                else  if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+                else  if (value[0] == 'r') {
+                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
+                    th_n_rollback = std::stoi(value.substr(1));
+                    if (th_n_rollback <= 0) {
+                        th_enabled = false;
+                    }
+                } else { th_enabled = false; }
+
+                LOG_VERBOSE("token healing", {
+                    {"id_slot",    slot.id},
+                    {"enabled",    th_enabled},
+                    {"type",       th_type},
+                    {"n_rollback", th_n_rollback}
+                });
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
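For reference, a minimal client-side sketch, not part of this patch: the prompt and `n_predict` values are made up, while the accepted value strings for the new `token_healing` field ("0", "1", "d1", "d", "r<N>") come from the parsing added above.

```cpp
// Illustrative request body exercising the new "token_healing" field.
// Assumption: the body is POSTed to the server's /completion endpoint.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json data = {
        {"prompt",        "The quick brown fo"},
        {"n_predict",     16},
        {"token_healing", "r3"}   // ROLLBACK_MULTI, roll back up to 3 prompt tokens
    };
    std::cout << data.dump(2) << '\n';
}
```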
@@ -1178,14 +1210,26 @@ struct server_context {
     }

     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;

+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
         if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
             // we can change penalty_prompt_tokens because it is always created from scratch each request
             slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
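Below is a small self-contained sketch (hypothetical token pieces and prefix length) of the `n_th_prefix` bookkeeping added to `process_token` above: pieces that fall entirely inside the healed prefix are suppressed, a piece that straddles the boundary is emitted only past the boundary, and later pieces stream through unchanged.

```cpp
// Standalone illustration of the prefix-suppression accounting; not server code.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
    size_t n_th_prefix = 4;            // e.g. the prompt suffix "worl" was rolled back
    std::string generated_text;

    // Hypothetical token pieces produced by the sampler after rollback.
    const std::vector<std::string> pieces = { "wor", "ld!", " How", " are", " you?" };

    for (const std::string & piece : pieces) {
        if (n_th_prefix > 0) {
            if (n_th_prefix < piece.size()) {
                generated_text += piece.substr(n_th_prefix); // keep only the new part
                n_th_prefix = 0;
            } else {
                n_th_prefix -= piece.size();                 // still inside the prefix: suppress
            }
        } else {
            generated_text += piece;
        }
    }

    std::cout << generated_text << '\n'; // prints "d! How are you?"
}
```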
@@ -1213,7 +1257,7 @@ struct server_context {
             break;
         }

-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

             const std::string str_test = slot.generated_text.substr(pos);
@@ -1245,7 +1289,7 @@ struct server_context {
             }
         }

-        if (incomplete) {
+        if (incomplete || is_token_healing) {
             slot.has_next_token = true;
         }

@@ -1350,7 +1394,8 @@ struct server_context {
             {"n_probs",               slot.sparams.n_probs},
             {"min_keep",              slot.sparams.min_keep},
             {"grammar",               slot.sparams.grammar},
-            {"samplers",              samplers_sequence}
+            {"samplers",              samplers_sequence},
+            {"token_healing_enabled", slot.sparams.token_healing_enabled}
         };
     }

@@ -2082,6 +2127,21 @@ struct server_context {
                         continue;
                     }

+                    // Roll back prompt tokens if token healing
+                    llama_token_healing_output token_healing_out{};
+                    if (slot.sparams.token_healing_enabled) {
+                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                         prompt_tokens, slot.sparams.token_healing_n_rollback);
+                        slot.n_th_prefix = token_healing_out.prefix.size();
+                        slot.n_prompt_tokens = prompt_tokens.size();
+                        LOG_VERBOSE("token healing prompt", {
+                            {"id_slot",          slot.id},
+                            {"id_task",          slot.id_task},
+                            {"removed_suffix",   token_healing_out.prefix},
+                            {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                        });
+                    }
+
                     if (slot.embedding) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_ubatch) {
@@ -2132,6 +2192,9 @@ struct server_context {
                     }

                     llama_sampling_reset(slot.ctx_sampling);
+                    if (slot.sparams.token_healing_enabled) {
+                        llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+                    }

                     if (!slot.params.cache_prompt) {
                         slot.n_past_se = 0;