Skip to content

Commit 7b8dbcb

Browse files
authored
main.cpp fixes, refactoring (#571)
- main: entering empty line passes back control without new input in interactive/instruct modes - instruct mode: keep prompt fix - instruct mode: duplicate instruct prompt fix - refactor: move common console code from main->common
1 parent 4b8efff commit 7b8dbcb

File tree

3 files changed

+143
-118
lines changed

3 files changed

+143
-118
lines changed

examples/common.cpp

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,20 @@
99
#include <iterator>
1010
#include <algorithm>
1111

12-
#if defined(_MSC_VER) || defined(__MINGW32__)
13-
#include <malloc.h> // using malloc.h with MSC/MINGW
14-
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
15-
#include <alloca.h>
16-
#endif
12+
#if defined(_MSC_VER) || defined(__MINGW32__)
13+
#include <malloc.h> // using malloc.h with MSC/MINGW
14+
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
15+
#include <alloca.h>
16+
#endif
17+
18+
#if defined (_WIN32)
19+
#pragma comment(lib,"kernel32.lib")
20+
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
21+
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
22+
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
23+
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
24+
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
25+
#endif
1726

1827
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
1928
// determine sensible default number of threads.
@@ -204,7 +213,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
204213
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
205214
fprintf(stderr, " -f FNAME, --file FNAME\n");
206215
fprintf(stderr, " prompt file to start generation.\n");
207-
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict);
216+
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
208217
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
209218
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
210219
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
@@ -216,7 +225,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
216225
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
217226
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
218227
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
219-
fprintf(stderr, " --keep number of tokens to keep from the initial prompt\n");
228+
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
220229
if (ggml_mlock_supported()) {
221230
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
222231
}
@@ -256,3 +265,47 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
256265

257266
return res;
258267
}
268+
269+
/* Keep track of current color of output, and emit ANSI code if it changes. */
270+
void set_console_color(console_state & con_st, console_color_t color) {
271+
if (con_st.use_color && con_st.color != color) {
272+
switch(color) {
273+
case CONSOLE_COLOR_DEFAULT:
274+
printf(ANSI_COLOR_RESET);
275+
break;
276+
case CONSOLE_COLOR_PROMPT:
277+
printf(ANSI_COLOR_YELLOW);
278+
break;
279+
case CONSOLE_COLOR_USER_INPUT:
280+
printf(ANSI_BOLD ANSI_COLOR_GREEN);
281+
break;
282+
}
283+
con_st.color = color;
284+
}
285+
}
286+
287+
#if defined (_WIN32)
288+
void win32_console_init(bool enable_color) {
289+
unsigned long dwMode = 0;
290+
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
291+
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
292+
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
293+
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
294+
hConOut = 0;
295+
}
296+
}
297+
if (hConOut) {
298+
// Enable ANSI colors on Windows 10+
299+
if (enable_color && !(dwMode & 0x4)) {
300+
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
301+
}
302+
// Set console output codepage to UTF8
303+
SetConsoleOutputCP(65001); // CP_UTF8
304+
}
305+
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
306+
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
307+
// Set console input codepage to UTF8
308+
SetConsoleCP(65001); // CP_UTF8
309+
}
310+
}
311+
#endif

examples/common.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,33 @@ std::string gpt_random_prompt(std::mt19937 & rng);
6363
//
6464

6565
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
66+
67+
//
68+
// Console utils
69+
//
70+
71+
#define ANSI_COLOR_RED "\x1b[31m"
72+
#define ANSI_COLOR_GREEN "\x1b[32m"
73+
#define ANSI_COLOR_YELLOW "\x1b[33m"
74+
#define ANSI_COLOR_BLUE "\x1b[34m"
75+
#define ANSI_COLOR_MAGENTA "\x1b[35m"
76+
#define ANSI_COLOR_CYAN "\x1b[36m"
77+
#define ANSI_COLOR_RESET "\x1b[0m"
78+
#define ANSI_BOLD "\x1b[1m"
79+
80+
enum console_color_t {
81+
CONSOLE_COLOR_DEFAULT=0,
82+
CONSOLE_COLOR_PROMPT,
83+
CONSOLE_COLOR_USER_INPUT
84+
};
85+
86+
struct console_state {
87+
bool use_color = false;
88+
console_color_t color = CONSOLE_COLOR_DEFAULT;
89+
};
90+
91+
void set_console_color(console_state & con_st, console_color_t color);
92+
93+
#if defined (_WIN32)
94+
void win32_console_init(bool enable_color);
95+
#endif

examples/main/main.cpp

Lines changed: 53 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,13 @@
1818
#include <signal.h>
1919
#endif
2020

21-
#if defined (_WIN32)
22-
#pragma comment(lib,"kernel32.lib")
23-
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
24-
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
25-
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
26-
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
27-
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
28-
#endif
29-
30-
#define ANSI_COLOR_RED "\x1b[31m"
31-
#define ANSI_COLOR_GREEN "\x1b[32m"
32-
#define ANSI_COLOR_YELLOW "\x1b[33m"
33-
#define ANSI_COLOR_BLUE "\x1b[34m"
34-
#define ANSI_COLOR_MAGENTA "\x1b[35m"
35-
#define ANSI_COLOR_CYAN "\x1b[36m"
36-
#define ANSI_COLOR_RESET "\x1b[0m"
37-
#define ANSI_BOLD "\x1b[1m"
38-
39-
/* Keep track of current color of output, and emit ANSI code if it changes. */
40-
enum console_state {
41-
CONSOLE_STATE_DEFAULT=0,
42-
CONSOLE_STATE_PROMPT,
43-
CONSOLE_STATE_USER_INPUT
44-
};
45-
46-
static console_state con_st = CONSOLE_STATE_DEFAULT;
47-
static bool con_use_color = false;
48-
49-
void set_console_state(console_state new_st) {
50-
if (!con_use_color) return;
51-
// only emit color code if state changed
52-
if (new_st != con_st) {
53-
con_st = new_st;
54-
switch(con_st) {
55-
case CONSOLE_STATE_DEFAULT:
56-
printf(ANSI_COLOR_RESET);
57-
return;
58-
case CONSOLE_STATE_PROMPT:
59-
printf(ANSI_COLOR_YELLOW);
60-
return;
61-
case CONSOLE_STATE_USER_INPUT:
62-
printf(ANSI_BOLD ANSI_COLOR_GREEN);
63-
return;
64-
}
65-
}
66-
}
21+
static console_state con_st;
6722

6823
static bool is_interacting = false;
6924

7025
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
7126
void sigint_handler(int signo) {
72-
set_console_state(CONSOLE_STATE_DEFAULT);
27+
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
7328
printf("\n"); // this also force flush stdout.
7429
if (signo == SIGINT) {
7530
if (!is_interacting) {
@@ -81,32 +36,6 @@ void sigint_handler(int signo) {
8136
}
8237
#endif
8338

84-
#if defined (_WIN32)
85-
void win32_console_init(void) {
86-
unsigned long dwMode = 0;
87-
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
88-
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
89-
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
90-
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
91-
hConOut = 0;
92-
}
93-
}
94-
if (hConOut) {
95-
// Enable ANSI colors on Windows 10+
96-
if (con_use_color && !(dwMode & 0x4)) {
97-
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
98-
}
99-
// Set console output codepage to UTF8
100-
SetConsoleOutputCP(65001); // CP_UTF8
101-
}
102-
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
103-
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
104-
// Set console input codepage to UTF8
105-
SetConsoleCP(65001); // CP_UTF8
106-
}
107-
}
108-
#endif
109-
11039
int main(int argc, char ** argv) {
11140
gpt_params params;
11241
params.model = "models/llama-7B/ggml-model.bin";
@@ -115,13 +44,12 @@ int main(int argc, char ** argv) {
11544
return 1;
11645
}
11746

118-
11947
// save choice to use color for later
12048
// (note for later: this is a slightly awkward choice)
121-
con_use_color = params.use_color;
49+
con_st.use_color = params.use_color;
12250

12351
#if defined (_WIN32)
124-
win32_console_init();
52+
win32_console_init(params.use_color);
12553
#endif
12654

12755
if (params.perplexity) {
@@ -218,24 +146,23 @@ int main(int argc, char ** argv) {
218146
return 1;
219147
}
220148

221-
params.n_keep = std::min(params.n_keep, (int) embd_inp.size());
149+
// number of tokens to keep when resetting context
150+
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
151+
params.n_keep = (int)embd_inp.size();
152+
}
222153

223154
// prefix & suffix for instruct mode
224155
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
225156
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
226157

227158
// in instruct mode, we inject a prefix and a suffix to each input by the user
228159
if (params.instruct) {
229-
params.interactive = true;
160+
params.interactive_start = true;
230161
params.antiprompt.push_back("### Instruction:\n\n");
231162
}
232163

233-
// enable interactive mode if reverse prompt is specified
234-
if (params.antiprompt.size() != 0) {
235-
params.interactive = true;
236-
}
237-
238-
if (params.interactive_start) {
164+
// enable interactive mode if reverse prompt or interactive start is specified
165+
if (params.antiprompt.size() != 0 || params.interactive_start) {
239166
params.interactive = true;
240167
}
241168

@@ -297,17 +224,18 @@ int main(int argc, char ** argv) {
297224
#endif
298225
" - Press Return to return control to LLaMa.\n"
299226
" - If you want to submit another line, end your input in '\\'.\n\n");
300-
is_interacting = params.interactive_start || params.instruct;
227+
is_interacting = params.interactive_start;
301228
}
302229

303-
bool input_noecho = false;
230+
bool is_antiprompt = false;
231+
bool input_noecho = false;
304232

305233
int n_past = 0;
306234
int n_remain = params.n_predict;
307235
int n_consumed = 0;
308236

309237
// the first thing we will do is to output the prompt, so set color accordingly
310-
set_console_state(CONSOLE_STATE_PROMPT);
238+
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
311239

312240
std::vector<llama_token> embd;
313241

@@ -408,36 +336,38 @@ int main(int argc, char ** argv) {
408336
}
409337
// reset color to default if we there is no pending user input
410338
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
411-
set_console_state(CONSOLE_STATE_DEFAULT);
339+
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
412340
}
413341

414342
// in interactive mode, and not currently processing queued inputs;
415343
// check if we should prompt the user for more
416344
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
345+
417346
// check for reverse prompt
418-
std::string last_output;
419-
for (auto id : last_n_tokens) {
420-
last_output += llama_token_to_str(ctx, id);
421-
}
347+
if (params.antiprompt.size()) {
348+
std::string last_output;
349+
for (auto id : last_n_tokens) {
350+
last_output += llama_token_to_str(ctx, id);
351+
}
422352

423-
// Check if each of the reverse prompts appears at the end of the output.
424-
for (std::string & antiprompt : params.antiprompt) {
425-
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
426-
is_interacting = true;
427-
set_console_state(CONSOLE_STATE_USER_INPUT);
428-
fflush(stdout);
429-
break;
353+
is_antiprompt = false;
354+
// Check if each of the reverse prompts appears at the end of the output.
355+
for (std::string & antiprompt : params.antiprompt) {
356+
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
357+
is_interacting = true;
358+
is_antiprompt = true;
359+
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
360+
fflush(stdout);
361+
break;
362+
}
430363
}
431364
}
432365

433366
if (n_past > 0 && is_interacting) {
434367
// potentially set color to indicate we are taking user input
435-
set_console_state(CONSOLE_STATE_USER_INPUT);
368+
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
436369

437370
if (params.instruct) {
438-
n_consumed = embd_inp.size();
439-
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
440-
441371
printf("\n> ");
442372
}
443373

@@ -463,16 +393,28 @@ int main(int argc, char ** argv) {
463393
} while (another_line);
464394

465395
// done taking input, reset color
466-
set_console_state(CONSOLE_STATE_DEFAULT);
396+
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
467397

468-
auto line_inp = ::llama_tokenize(ctx, buffer, false);
469-
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
398+
// Add tokens to embd only if the input buffer is non-empty
399+
// Entering a empty line lets the user pass control back
400+
if (buffer.length() > 1) {
470401

471-
if (params.instruct) {
472-
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
473-
}
402+
// instruct mode: insert instruction prefix
403+
if (params.instruct && !is_antiprompt) {
404+
n_consumed = embd_inp.size();
405+
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
406+
}
474407

475-
n_remain -= line_inp.size();
408+
auto line_inp = ::llama_tokenize(ctx, buffer, false);
409+
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
410+
411+
// instruct mode: insert response suffix
412+
if (params.instruct) {
413+
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
414+
}
415+
416+
n_remain -= line_inp.size();
417+
}
476418

477419
input_noecho = true; // do not echo this again
478420
}
@@ -506,7 +448,7 @@ int main(int argc, char ** argv) {
506448
llama_print_timings(ctx);
507449
llama_free(ctx);
508450

509-
set_console_state(CONSOLE_STATE_DEFAULT);
451+
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
510452

511453
return 0;
512454
}

0 commit comments

Comments
 (0)