@@ -834,13 +834,16 @@ int llama_main(
     llama_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us) {
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
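The three new parameters make llama_main stream-agnostic: the input, output, and error channels are injected by the caller instead of being hard-coded to std::cin, stdout, and stderr. Input stays a C++ std::istream (it is read with std::getline further down) while output and errors stay C FILE* handles, so each printf call site becomes an fprintf with minimal churn. A minimal call-site sketch, assuming the surrounding setup in main.cpp has already produced params, vocab, model, and the two timestamps; passing the standard streams reproduces the pre-refactor behavior:

    // hypothetical call site; behaves identically to the old signature
    int ret = llama_main(params, vocab, model, t_load_us, t_main_start_us,
                         std::cin, stdout, stderr);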
@@ -888,13 +891,13 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -906,16 +909,16 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
 
         if (params.antiprompt.size()) {
             for (auto antiprompt : params.antiprompt) {
-                fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
+                fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
             }
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
 
     std::vector<llama_vocab::id> embd;
 
@@ -924,7 +927,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -948,7 +951,7 @@ int llama_main(
             SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
         }
 #endif
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
 
     while (remaining_tokens > 0 || params.interactive) {
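The prompt color now goes to outstream; the Windows-only block just above it first enables ANSI escape processing so the color codes render at all. A standalone sketch of that enabling step, assuming a Win32 console (the 0x4 constant is ENABLE_VIRTUAL_TERMINAL_PROCESSING, as the diff's own comment notes):

    #if defined(_WIN32)
    #include <windows.h>
    // Turn on ANSI escape handling for the console so sequences such as
    // ANSI_COLOR_YELLOW are interpreted instead of printed literally.
    static void enable_vt_mode(void) {
        HANDLE hConOut = GetStdHandle(STD_OUTPUT_HANDLE);
        DWORD  dwMode  = 0;
        if (hConOut != INVALID_HANDLE_VALUE && GetConsoleMode(hConOut, &dwMode)) {
            SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING
        }
    }
    #endif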
@@ -957,7 +960,7 @@ int llama_main(
         const int64_t t_start_us = ggml_time_us();
 
         if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-            fprintf(stderr, "Failed to predict\n");
+            fprintf(errstream, "Failed to predict\n");
             return 1;
         }
@@ -1018,13 +1021,13 @@ int llama_main(
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
-            printf(ANSI_COLOR_RESET);
+            fprintf(outstream, ANSI_COLOR_RESET);
         }
 
         // in interactive mode, and not currently processing queued inputs;
@@ -1048,24 +1051,24 @@ int llama_main(
                     input_consumed = embd_inp.size();
                     embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
 
-                    printf("\n> ");
+                    fprintf(outstream, "\n> ");
                 }
 
                 // currently being interactive
-                if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                if (params.use_color) fprintf(outstream, ANSI_BOLD ANSI_COLOR_GREEN);
                 std::string buffer;
                 std::string line;
                 bool another_line = true;
                 do {
-                    std::getline(std::cin, line);
+                    std::getline(instream, line);
                     if (line.empty() || line.back() != '\\') {
                         another_line = false;
                     } else {
                         line.pop_back(); // Remove the continue character
                     }
                     buffer += line + '\n'; // Append the line to the result
                 } while (another_line);
-                if (params.use_color) printf(ANSI_COLOR_RESET);
+                if (params.use_color) fprintf(outstream, ANSI_COLOR_RESET);
 
                 std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
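Because the continuation loop above now reads from instream, the backslash line-joining logic can be exercised against any std::istream rather than only the console. A self-contained sketch of the same technique (a standalone illustration, not the project's code):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Read lines from `in`, treating a trailing '\' as a continuation marker,
    // mirroring the do/while loop in llama_main.
    static std::string read_multiline(std::istream & in) {
        std::string buffer;
        std::string line;
        bool another_line = true;
        do {
            if (!std::getline(in, line)) break; // also stop on EOF
            if (line.empty() || line.back() != '\\') {
                another_line = false;
            } else {
                line.pop_back(); // drop the continuation character
            }
            buffer += line + '\n';
        } while (another_line);
        return buffer;
    }

    int main() {
        std::istringstream in("first \\\nsecond\nthird\n");
        std::cout << read_multiline(in); // prints "first " and "second"; "third" is never read
    }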
@@ -1086,7 +1089,7 @@ int llama_main(
         if (params.interactive) {
             is_interacting = true;
         } else {
-            fprintf(stderr, "[end of text]\n");
+            fprintf(errstream, "[end of text]\n");
             break;
         }
     }
@@ -1106,18 +1109,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
     ggml_free(model.ctx);
 
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
 
     return 0;
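With all three channels parameterized, a caller can script the model's input and capture its generation without touching the process-wide streams, while diagnostics remain separable on errstream. A hedged example of what the new signature permits (setup of params, vocab, model, and the timing values is assumed, as above):

    // hypothetical: scripted input, generation captured to a file,
    // diagnostics still on the real stderr
    std::istringstream scripted("Hello llama\n");
    FILE * out = fopen("generation.txt", "w");
    if (out != NULL) {
        llama_main(params, vocab, model, t_load_us, t_main_start_us,
                   scripted, out, stderr);
        fclose(out);
    }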