@@ -280,11 +280,10 @@ int main(int argc, char ** argv) {
280
280
sparams.token_healing_enabled = false ;
281
281
LOG (" token healing: disabled due to custom suffix/conversation mode" );
282
282
}
283
- std::string token_healing_prefix;
284
- int token_healing_n_removed = 0 ;
283
+ llama_token_healing_output token_healing_out{};
285
284
if (!params.interactive_first && sparams.token_healing_enabled ) {
286
- token_healing_prefix = llama_token_healing_rollback (ctx, sparams.token_healing_type , embd_inp,
287
- sparams.token_healing_n_rollback , &token_healing_n_removed );
285
+ token_healing_out = llama_token_healing_rollback (ctx, sparams.token_healing_type , embd_inp,
286
+ sparams.token_healing_n_rollback );
288
287
}
289
288
290
289
// Should not run without any tokens
@@ -306,7 +305,7 @@ int main(int argc, char ** argv) {
306
305
std::vector<llama_token> original_inp = ::llama_tokenize (ctx, params.prompt , true , true );
307
306
LOG (" original_inp tokenized: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, original_inp).c_str ());
308
307
309
- original_prompt_len = original_inp.size () - token_healing_n_removed ;
308
+ original_prompt_len = original_inp.size () - token_healing_out. n_tokens_removed ;
310
309
guidance_offset = (int )guidance_inp.size () - original_prompt_len;
311
310
LOG (" original_prompt_len: %s" , log_tostr (original_prompt_len));
312
311
LOG (" guidance_offset: %s" , log_tostr (guidance_offset));
@@ -528,7 +527,7 @@ int main(int argc, char ** argv) {
528
527
fprintf (stderr, " %s: failed to initialize sampling subsystem\n " , __func__);
529
528
exit (1 );
530
529
}
531
- llama_token_healing_set_prefix (ctx_sampling, token_healing_prefix );
530
+ llama_token_healing_set_prefix (ctx_sampling, token_healing_out. prefix );
532
531
533
532
while ((n_remain != 0 && !is_antiprompt) || params.interactive ) {
534
533
// predict
@@ -845,7 +844,8 @@ int main(int argc, char ** argv) {
845
844
assistant_ss << llama_token_to_piece (ctx, id, false );
846
845
}
847
846
848
- token_healing_n_removed = 0 ;
847
+ token_healing_out = {};
848
+
849
849
if (n_past > 0 && is_interacting) {
850
850
LOG (" waiting for user input\n " );
851
851
@@ -917,9 +917,8 @@ int main(int argc, char ** argv) {
917
917
const int max_to_remove = sparams.token_healing_n_rollback < 0
918
918
? n_new_tokens
919
919
: std::min (sparams.token_healing_n_rollback , n_new_tokens);
920
- token_healing_prefix = llama_token_healing_rollback (ctx, sparams.token_healing_type , embd_inp,
921
- max_to_remove, &token_healing_n_removed);
922
- n_bytes_to_skip = token_healing_prefix.size ();
920
+ token_healing_out = llama_token_healing_rollback (ctx, sparams.token_healing_type , embd_inp, max_to_remove);
921
+ n_bytes_to_skip = token_healing_out.prefix .size ();
923
922
}
924
923
925
924
for (size_t i = original_size; i < embd_inp.size (); ++i) {
@@ -931,7 +930,7 @@ int main(int argc, char ** argv) {
931
930
// reset assistant message
932
931
assistant_ss.str (" " );
933
932
934
- n_remain -= line_inp.size () + token_healing_n_removed ;
933
+ n_remain -= line_inp.size () + token_healing_out. n_tokens_removed ;
935
934
LOG (" n_remain: %d\n " , n_remain);
936
935
} else {
937
936
LOG (" empty line, passing control back\n " );
@@ -943,9 +942,9 @@ int main(int argc, char ** argv) {
943
942
if (n_past > 0 ) {
944
943
if (is_interacting) {
945
944
llama_sampling_reset (ctx_sampling);
946
- if (token_healing_n_removed > 0 ) {
945
+ if (token_healing_out. n_tokens_removed > 0 ) {
947
946
// Set new prefix after an interaction
948
- llama_token_healing_set_prefix (ctx_sampling, token_healing_prefix );
947
+ llama_token_healing_set_prefix (ctx_sampling, token_healing_out. prefix );
949
948
}
950
949
}
951
950
is_interacting = false ;
0 commit comments