@@ -44,21 +44,21 @@ enum console_state {
44
44
static console_state con_st = CONSOLE_STATE_DEFAULT;
45
45
static bool con_use_color = false ;
46
46
47
- void set_console_state (console_state new_st)
47
+ void set_console_state (FILE *stream, console_state new_st)
48
48
{
49
49
if (!con_use_color) return ;
50
50
// only emit color code if state changed
51
51
if (new_st != con_st) {
52
52
con_st = new_st;
53
53
switch (con_st) {
54
54
case CONSOLE_STATE_DEFAULT:
55
- printf ( ANSI_COLOR_RESET);
55
+ fprintf (stream, ANSI_COLOR_RESET);
56
56
return ;
57
57
case CONSOLE_STATE_PROMPT:
58
- printf ( ANSI_COLOR_YELLOW);
58
+ fprintf (stream, ANSI_COLOR_YELLOW);
59
59
return ;
60
60
case CONSOLE_STATE_USER_INPUT:
61
- printf ( ANSI_BOLD ANSI_COLOR_GREEN);
61
+ fprintf (stream, ANSI_BOLD ANSI_COLOR_GREEN);
62
62
return ;
63
63
}
64
64
}
@@ -68,7 +68,7 @@ static bool is_interacting = false;
68
68
69
69
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
70
70
void sigint_handler (int signo) {
71
- set_console_state (CONSOLE_STATE_DEFAULT);
71
+ set_console_state (stdout, CONSOLE_STATE_DEFAULT);
72
72
printf (" \n " ); // this also force flush stdout.
73
73
if (signo == SIGINT) {
74
74
if (!is_interacting) {
@@ -80,13 +80,17 @@ void sigint_handler(int signo) {
80
80
}
81
81
#endif
82
82
83
- int run (llama_context * ctx, gpt_params params) {
83
+ int run (llama_context * ctx,
84
+ gpt_params params,
85
+ std::istream & instream,
86
+ FILE *outstream,
87
+ FILE *errstream) {
84
88
85
89
if (params.seed <= 0 ) {
86
90
params.seed = time (NULL );
87
91
}
88
92
89
- fprintf (stderr , " %s: seed = %d\n " , __func__, params.seed );
93
+ fprintf (errstream , " %s: seed = %d\n " , __func__, params.seed );
90
94
91
95
std::mt19937 rng (params.seed );
92
96
if (params.random_prompt ) {
@@ -138,13 +142,13 @@ int run(llama_context * ctx, gpt_params params) {
138
142
params.interactive = true ;
139
143
}
140
144
141
- fprintf (stderr , " \n " );
142
- fprintf (stderr , " %s: prompt: '%s'\n " , __func__, params.prompt .c_str ());
143
- fprintf (stderr , " %s: number of tokens in prompt = %zu\n " , __func__, embd_inp.size ());
145
+ fprintf (errstream , " \n " );
146
+ fprintf (errstream , " %s: prompt: '%s'\n " , __func__, params.prompt .c_str ());
147
+ fprintf (errstream , " %s: number of tokens in prompt = %zu\n " , __func__, embd_inp.size ());
144
148
for (int i = 0 ; i < (int ) embd_inp.size (); i++) {
145
- fprintf (stderr , " %6d -> '%s'\n " , embd_inp[i], llama_token_to_str (ctx, embd_inp[i]));
149
+ fprintf (errstream , " %6d -> '%s'\n " , embd_inp[i], llama_token_to_str (ctx, embd_inp[i]));
146
150
}
147
- fprintf (stderr , " \n " );
151
+ fprintf (errstream , " \n " );
148
152
if (params.interactive ) {
149
153
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
150
154
struct sigaction sigint_action;
@@ -156,16 +160,16 @@ int run(llama_context * ctx, gpt_params params) {
156
160
signal (SIGINT, sigint_handler);
157
161
#endif
158
162
159
- fprintf (stderr , " %s: interactive mode on.\n " , __func__);
163
+ fprintf (errstream , " %s: interactive mode on.\n " , __func__);
160
164
161
165
if (params.antiprompt .size ()) {
162
166
for (auto antiprompt : params.antiprompt ) {
163
- fprintf (stderr , " Reverse prompt: '%s'\n " , antiprompt.c_str ());
167
+ fprintf (errstream , " Reverse prompt: '%s'\n " , antiprompt.c_str ());
164
168
}
165
169
}
166
170
}
167
- fprintf (stderr , " sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
168
- fprintf (stderr , " \n\n " );
171
+ fprintf (errstream , " sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
172
+ fprintf (errstream , " \n\n " );
169
173
170
174
std::vector<llama_token> embd;
171
175
@@ -174,7 +178,7 @@ int run(llama_context * ctx, gpt_params params) {
174
178
std::fill (last_n_tokens.begin (), last_n_tokens.end (), 0 );
175
179
176
180
if (params.interactive ) {
177
- fprintf (stderr , " == Running in interactive mode. ==\n "
181
+ fprintf (errstream , " == Running in interactive mode. ==\n "
178
182
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
179
183
" - Press Ctrl+C to interject at any time.\n "
180
184
#endif
@@ -199,13 +203,13 @@ int run(llama_context * ctx, gpt_params params) {
199
203
}
200
204
#endif
201
205
// the first thing we will do is to output the prompt, so set color accordingly
202
- set_console_state (CONSOLE_STATE_PROMPT);
206
+ set_console_state (outstream, CONSOLE_STATE_PROMPT);
203
207
204
208
while (remaining_tokens > 0 || params.interactive ) {
205
209
// predict
206
210
if (embd.size () > 0 ) {
207
211
if (llama_eval (ctx, embd.data (), embd.size (), n_past, params.n_threads )) {
208
- fprintf (stderr , " %s : failed to eval\n " , __func__);
212
+ fprintf (errstream , " %s : failed to eval\n " , __func__);
209
213
return 1 ;
210
214
}
211
215
}
@@ -263,13 +267,13 @@ int run(llama_context * ctx, gpt_params params) {
263
267
// display text
264
268
if (!input_noecho) {
265
269
for (auto id : embd) {
266
- printf ( " %s" , llama_token_to_str (ctx, id));
270
+ fprintf (outstream, " %s" , llama_token_to_str (ctx, id));
267
271
}
268
- fflush (stdout );
272
+ fflush (outstream );
269
273
}
270
274
// reset color to default if there is no pending user input
271
275
if (!input_noecho && (int )embd_inp.size () == input_consumed) {
272
- set_console_state (CONSOLE_STATE_DEFAULT);
276
+ set_console_state (outstream, CONSOLE_STATE_DEFAULT);
273
277
}
274
278
275
279
// in interactive mode, and not currently processing queued inputs;
@@ -290,20 +294,20 @@ int run(llama_context * ctx, gpt_params params) {
290
294
}
291
295
if (is_interacting) {
292
296
// potentially set color to indicate we are taking user input
293
- set_console_state (CONSOLE_STATE_USER_INPUT);
297
+ set_console_state (outstream, CONSOLE_STATE_USER_INPUT);
294
298
295
299
if (params.instruct ) {
296
300
input_consumed = embd_inp.size ();
297
301
embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
298
302
299
- printf ( " \n > " );
303
+ fprintf (outstream, " \n > " );
300
304
}
301
305
302
306
std::string buffer;
303
307
std::string line;
304
308
bool another_line = true ;
305
309
do {
306
- std::getline (std::cin , line);
310
+ std::getline (instream , line);
307
311
if (line.empty () || line.back () != ' \\ ' ) {
308
312
another_line = false ;
309
313
} else {
@@ -313,7 +317,7 @@ int run(llama_context * ctx, gpt_params params) {
313
317
} while (another_line);
314
318
315
319
// done taking input, reset color
316
- set_console_state (CONSOLE_STATE_DEFAULT);
320
+ set_console_state (outstream, CONSOLE_STATE_DEFAULT);
317
321
318
322
auto line_inp = ::llama_tokenize (ctx, buffer, false );
319
323
embd_inp.insert (embd_inp.end (), line_inp.begin (), line_inp.end ());
@@ -334,7 +338,7 @@ int run(llama_context * ctx, gpt_params params) {
334
338
if (params.interactive ) {
335
339
is_interacting = true ;
336
340
} else {
337
- fprintf (stderr , " [end of text]\n " );
341
+ fprintf (errstream , " [end of text]\n " );
338
342
break ;
339
343
}
340
344
}
@@ -354,7 +358,7 @@ int run(llama_context * ctx, gpt_params params) {
354
358
355
359
llama_free (ctx);
356
360
357
- set_console_state (CONSOLE_STATE_DEFAULT);
361
+ set_console_state (outstream, CONSOLE_STATE_DEFAULT);
358
362
359
363
return 0 ;
360
364
}
0 commit comments