Skip to content

Commit bf44faa

Browse files
committed
Remove direct access to std streams from "run"
The goal is to allow running "run" while connected to other streams, such as TCP sockets. Signed-off-by: Thiago Padilha <[email protected]>
1 parent b7f1fa6 commit bf44faa

File tree

3 files changed

+40
-30
lines changed

3 files changed

+40
-30
lines changed

main.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "run.h"
22
#include "ggml.h"
33

4+
#include <iostream>
5+
46

57
std::vector<double> softmax(const std::vector<float>& logits) {
68
std::vector<double> probs(logits.size());
@@ -123,5 +125,5 @@ int main(int argc, char ** argv) {
123125
exit(0);
124126
}
125127

126-
return run(ctx, params);
128+
return run(ctx, params, std::cin, stdout, stderr);
127129
}

run.cpp

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,21 @@ enum console_state {
4444
static console_state con_st = CONSOLE_STATE_DEFAULT;
4545
static bool con_use_color = false;
4646

47-
void set_console_state(console_state new_st)
47+
void set_console_state(FILE *stream, console_state new_st)
4848
{
4949
if (!con_use_color) return;
5050
// only emit color code if state changed
5151
if (new_st != con_st) {
5252
con_st = new_st;
5353
switch(con_st) {
5454
case CONSOLE_STATE_DEFAULT:
55-
printf(ANSI_COLOR_RESET);
55+
fprintf(stream, ANSI_COLOR_RESET);
5656
return;
5757
case CONSOLE_STATE_PROMPT:
58-
printf(ANSI_COLOR_YELLOW);
58+
fprintf(stream, ANSI_COLOR_YELLOW);
5959
return;
6060
case CONSOLE_STATE_USER_INPUT:
61-
printf(ANSI_BOLD ANSI_COLOR_GREEN);
61+
fprintf(stream, ANSI_BOLD ANSI_COLOR_GREEN);
6262
return;
6363
}
6464
}
@@ -68,7 +68,7 @@ static bool is_interacting = false;
6868

6969
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
7070
void sigint_handler(int signo) {
71-
set_console_state(CONSOLE_STATE_DEFAULT);
71+
set_console_state(stdout, CONSOLE_STATE_DEFAULT);
7272
printf("\n"); // this also forces a flush of stdout.
7373
if (signo == SIGINT) {
7474
if (!is_interacting) {
@@ -80,13 +80,17 @@ void sigint_handler(int signo) {
8080
}
8181
#endif
8282

83-
int run(llama_context * ctx, gpt_params params) {
83+
int run(llama_context * ctx,
84+
gpt_params params,
85+
std::istream & instream,
86+
FILE *outstream,
87+
FILE *errstream) {
8488

8589
if (params.seed <= 0) {
8690
params.seed = time(NULL);
8791
}
8892

89-
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
93+
fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
9094

9195
std::mt19937 rng(params.seed);
9296
if (params.random_prompt) {
@@ -138,13 +142,13 @@ int run(llama_context * ctx, gpt_params params) {
138142
params.interactive = true;
139143
}
140144

141-
fprintf(stderr, "\n");
142-
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
143-
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
145+
fprintf(errstream, "\n");
146+
fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
147+
fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
144148
for (int i = 0; i < (int) embd_inp.size(); i++) {
145-
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
149+
fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
146150
}
147-
fprintf(stderr, "\n");
151+
fprintf(errstream, "\n");
148152
if (params.interactive) {
149153
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
150154
struct sigaction sigint_action;
@@ -156,16 +160,16 @@ int run(llama_context * ctx, gpt_params params) {
156160
signal(SIGINT, sigint_handler);
157161
#endif
158162

159-
fprintf(stderr, "%s: interactive mode on.\n", __func__);
163+
fprintf(errstream, "%s: interactive mode on.\n", __func__);
160164

161165
if(params.antiprompt.size()) {
162166
for (auto antiprompt : params.antiprompt) {
163-
fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
167+
fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
164168
}
165169
}
166170
}
167-
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
168-
fprintf(stderr, "\n\n");
171+
fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
172+
fprintf(errstream, "\n\n");
169173

170174
std::vector<llama_token> embd;
171175

@@ -174,7 +178,7 @@ int run(llama_context * ctx, gpt_params params) {
174178
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
175179

176180
if (params.interactive) {
177-
fprintf(stderr, "== Running in interactive mode. ==\n"
181+
fprintf(errstream, "== Running in interactive mode. ==\n"
178182
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
179183
" - Press Ctrl+C to interject at any time.\n"
180184
#endif
@@ -199,13 +203,13 @@ int run(llama_context * ctx, gpt_params params) {
199203
}
200204
#endif
201205
// the first thing we will do is to output the prompt, so set color accordingly
202-
set_console_state(CONSOLE_STATE_PROMPT);
206+
set_console_state(outstream, CONSOLE_STATE_PROMPT);
203207

204208
while (remaining_tokens > 0 || params.interactive) {
205209
// predict
206210
if (embd.size() > 0) {
207211
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
208-
fprintf(stderr, "%s : failed to eval\n", __func__);
212+
fprintf(errstream, "%s : failed to eval\n", __func__);
209213
return 1;
210214
}
211215
}
@@ -263,13 +267,13 @@ int run(llama_context * ctx, gpt_params params) {
263267
// display text
264268
if (!input_noecho) {
265269
for (auto id : embd) {
266-
printf("%s", llama_token_to_str(ctx, id));
270+
fprintf(outstream, "%s", llama_token_to_str(ctx, id));
267271
}
268-
fflush(stdout);
272+
fflush(outstream);
269273
}
270274
// reset color to default if there is no pending user input
271275
if (!input_noecho && (int)embd_inp.size() == input_consumed) {
272-
set_console_state(CONSOLE_STATE_DEFAULT);
276+
set_console_state(outstream, CONSOLE_STATE_DEFAULT);
273277
}
274278

275279
// in interactive mode, and not currently processing queued inputs;
@@ -290,20 +294,20 @@ int run(llama_context * ctx, gpt_params params) {
290294
}
291295
if (is_interacting) {
292296
// potentially set color to indicate we are taking user input
293-
set_console_state(CONSOLE_STATE_USER_INPUT);
297+
set_console_state(outstream, CONSOLE_STATE_USER_INPUT);
294298

295299
if (params.instruct) {
296300
input_consumed = embd_inp.size();
297301
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
298302

299-
printf("\n> ");
303+
fprintf(outstream, "\n> ");
300304
}
301305

302306
std::string buffer;
303307
std::string line;
304308
bool another_line = true;
305309
do {
306-
std::getline(std::cin, line);
310+
std::getline(instream, line);
307311
if (line.empty() || line.back() != '\\') {
308312
another_line = false;
309313
} else {
@@ -313,7 +317,7 @@ int run(llama_context * ctx, gpt_params params) {
313317
} while (another_line);
314318

315319
// done taking input, reset color
316-
set_console_state(CONSOLE_STATE_DEFAULT);
320+
set_console_state(outstream, CONSOLE_STATE_DEFAULT);
317321

318322
auto line_inp = ::llama_tokenize(ctx, buffer, false);
319323
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
@@ -334,7 +338,7 @@ int run(llama_context * ctx, gpt_params params) {
334338
if (params.interactive) {
335339
is_interacting = true;
336340
} else {
337-
fprintf(stderr, " [end of text]\n");
341+
fprintf(errstream, " [end of text]\n");
338342
break;
339343
}
340344
}
@@ -354,7 +358,7 @@ int run(llama_context * ctx, gpt_params params) {
354358

355359
llama_free(ctx);
356360

357-
set_console_state(CONSOLE_STATE_DEFAULT);
361+
set_console_state(outstream, CONSOLE_STATE_DEFAULT);
358362

359363
return 0;
360364
}

run.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@
33
#include "llama.h"
44
#include "utils.h"
55

6-
int run(llama_context * ctx, gpt_params params);
6+
int run(llama_context * ctx,
7+
gpt_params params,
8+
std::istream & instream,
9+
FILE *outstream,
10+
FILE *errstream);

0 commit comments

Comments
 (0)