Skip to content

Commit 623055e

Browse files
committed
Add author mode and other related QOL improvements
1 parent 305eb5a commit 623055e

File tree

3 files changed

+184
-50
lines changed

3 files changed

+184
-50
lines changed

examples/common.cpp

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
2222
const wchar_t * lpWideCharStr, int cchWideChar,
2323
char * lpMultiByteStr, int cbMultiByte,
2424
const char * lpDefaultChar, bool * lpUsedDefaultChar);
25+
#define ENABLE_LINE_INPUT 0x0002
26+
#define ENABLE_ECHO_INPUT 0x0004
2527
#define CP_UTF8 65001
28+
#define CONSOLE_CHAR_TYPE wchar_t
29+
#define CONSOLE_GET_CHAR() getwchar()
30+
#define CONSOLE_EOF WEOF
31+
#else
32+
#include <unistd.h>
33+
#define CONSOLE_CHAR_TYPE char
34+
#define CONSOLE_GET_CHAR() getchar()
35+
#define CONSOLE_EOF EOF
2636
#endif
2737

2838
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -208,6 +218,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
208218
params.embedding = true;
209219
} else if (arg == "--interactive-first") {
210220
params.interactive_first = true;
221+
} else if (arg == "--author-mode") {
222+
params.author_mode = true;
211223
} else if (arg == "-ins" || arg == "--instruct") {
212224
params.instruct = true;
213225
} else if (arg == "--color") {
@@ -291,6 +303,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
291303
fprintf(stderr, " -i, --interactive run in interactive mode\n");
292304
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
293305
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
306+
fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n");
294307
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
295308
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
296309
fprintf(stderr, " specified more than once for multiple prompts).\n");
@@ -377,7 +390,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
377390
}
378391

379392
/* Keep track of current color of output, and emit ANSI code if it changes. */
380-
void set_console_color(console_state & con_st, console_color_t color) {
393+
void console_set_color(console_state & con_st, console_color_t color) {
381394
if (con_st.use_color && con_st.color != color) {
382395
switch(color) {
383396
case CONSOLE_COLOR_DEFAULT:
@@ -394,8 +407,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
394407
}
395408
}
396409

410+
void console_init(console_state & con_st) {
397411
#if defined (_WIN32)
398-
void win32_console_init(bool enable_color) {
412+
// Windows-specific console initialization
399413
unsigned long dwMode = 0;
400414
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
401415
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
@@ -406,7 +420,7 @@ void win32_console_init(bool enable_color) {
406420
}
407421
if (hConOut) {
408422
// Enable ANSI colors on Windows 10+
409-
if (enable_color && !(dwMode & 0x4)) {
423+
if (con_st.use_color && !(dwMode & 0x4)) {
410424
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
411425
}
412426
// Set console output codepage to UTF8
@@ -416,9 +430,46 @@ void win32_console_init(bool enable_color) {
416430
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
417431
// Set console input codepage to UTF16
418432
_setmode(_fileno(stdin), _O_WTEXT);
433+
434+
// Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
435+
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
436+
SetConsoleMode(hConIn, dwMode);
437+
}
438+
#else
439+
// POSIX-specific console initialization
440+
struct termios new_termios;
441+
tcgetattr(STDIN_FILENO, &con_st.prev_state);
442+
new_termios = con_st.prev_state;
443+
new_termios.c_lflag &= ~(ICANON | ECHO);
444+
new_termios.c_cc[VMIN] = 1;
445+
new_termios.c_cc[VTIME] = 0;
446+
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
447+
#endif
448+
}
449+
450+
// Put the console back into a sane state before the program exits.
// On POSIX this re-applies the termios settings stored in con_st.prev_state;
// on every platform it then switches the output color back to the default.
void console_cleanup(console_state & con_st) {
#ifndef _WIN32
    // POSIX only: re-apply the saved terminal attributes immediately
    tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
#endif

    // Make sure no color is left active on the terminal
    console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
459+
460+
// Strip the final UTF-8 encoded character from `line`; an empty string is left untouched.
void remove_last_utf8_char(std::string & line) {
    if (line.empty()) {
        return;
    }

    // Scan backwards from the last byte over at most three continuation bytes
    // (bit pattern 10xxxxxx). The scan stops on the first byte that is not a
    // continuation byte, or when the beginning of the string is reached.
    size_t start = line.size() - 1;
    size_t stepped = 0;
    while (stepped < 3 && start > 0 &&
           (static_cast<unsigned char>(line[start]) & 0xC0) == 0x80) {
        ++stepped;
        --start;
    }

    // Everything from `start` onwards belongs to the character being removed
    line.resize(start);
}
421471

472+
#if defined (_WIN32)
422473
// Convert a wide Unicode string to an UTF8 string
423474
void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
424475
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL);
@@ -427,3 +478,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
427478
str = strTo;
428479
}
429480
#endif
481+
482+
// Read one line of user input, one character at a time, echoing it manually.
//
// The function performs its own echo, backspace handling and escape-sequence
// filtering; console_init() switches the console to unbuffered, no-echo input,
// which is what this relies on.
//
// con_st  console state; `author_mode` selects the default return value and the
//         color tracking is updated through console_set_color().
// line    output; cleared first, then filled with the UTF-8 text that was typed.
//           - input ended with Return:      '\n' is appended
//           - input ended with '\' + Return: the '\' is replaced by '\n'
//           - input ended with '/' + Return: the '/' is dropped, no '\n' is added
//             (a line consisting of a single space is cleared entirely)
//           - end of stream (EOF / Ctrl+D):  nothing is appended
//
// Returns true when the caller should read another line. The default is
// con_st.author_mode; a trailing '\' inverts that default, while a trailing '/'
// or end of stream always yields false.
bool console_readline(console_state & con_st, std::string & line) {
    line.clear();
    bool is_special_char = false; // last echoed char is a highlighted '\' or '/'
    bool end_of_stream   = false;

    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);

    while (true) {
        fflush(stdout); // Ensure all output is displayed before waiting for input

        // Keep the reader's own return type (int for getchar, wint_t for getwchar).
        // Narrowing to char first would make bytes >= 0x80 negative on signed-char
        // platforms (so UTF-8 input would be discarded as "control characters" and
        // 0xFF would be mistaken for EOF) and would make EOF undetectable on
        // unsigned-char platforms.
        const auto input_char = CONSOLE_GET_CHAR();

        if (input_char == '\r' || input_char == '\n') {
            break;
        }

        if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) {
            end_of_stream = true;
            break;
        }

        if (is_special_char) {
            // More input follows, so the trailing '\' or '/' is ordinary text after all:
            // repaint it in the user-input color
            console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
            putchar('\b');
            putchar(line.back());
            is_special_char = false;
        }

        if (input_char == '\033') { // Escape sequence
            auto code = CONSOLE_GET_CHAR();
            if (code == '[') {
                // Discard the rest of the escape sequence (terminated by a letter or '~')
                while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) {
                    if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
                        break;
                    }
                }
            }
        } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
            if (!line.empty()) {
                fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
                remove_last_utf8_char(line);
            }
        } else if (input_char < 32) {
            // Ignore control characters
        } else {
#if defined(_WIN32)
            std::string utf8_char;
            win32_utf8_encode(std::wstring(1, static_cast<wchar_t>(input_char)), utf8_char);
            line += utf8_char;
            fputs(utf8_char.c_str(), stdout);
#else
            line += static_cast<char>(input_char);
            putchar(input_char);
#endif
        }

        // Highlight a trailing '\' or '/' to show that it will act as a control character
        if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
            console_set_color(con_st, CONSOLE_COLOR_PROMPT);
            putchar('\b');
            putchar(line.back());
            is_special_char = true;
        }
    }

    bool has_more = con_st.author_mode;
    if (is_special_char) {
        fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again

        const char last = line.back();
        line.pop_back();
        if (last == '\\') {
            // '\' inverts the mode's default continuation behaviour
            line += '\n';
            putchar('\n');
            has_more = !has_more;
        } else {
            // llama doesn't seem to process a single space
            if (line.length() == 1 && line.back() == ' ') {
                line.clear();
                putchar('\b');
            }
            has_more = false;
        }
    } else {
        if (end_of_stream) {
            has_more = false;
        } else {
            line += '\n';
            putchar('\n');
        }
    }

    fflush(stdout);
    return has_more;
}

examples/common.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#include <thread>
1111
#include <unordered_map>
1212

13+
#if !defined (_WIN32)
14+
#include <termios.h>
15+
#endif
16+
1317
//
1418
// CLI argument parsing
1519
//
@@ -54,6 +58,7 @@ struct gpt_params {
5458

5559
bool embedding = false; // get only sentence embedding
5660
bool interactive_first = false; // wait for user input immediately
61+
bool author_mode = false; // reverse the usage of `\`
5762

5863
bool instruct = false; // instruction mode (used for Alpaca models)
5964
bool penalize_nl = true; // consider newlines as a repeatable token
@@ -96,13 +101,15 @@ enum console_color_t {
96101
};
97102

98103
struct console_state {
104+
bool author_mode = false;
99105
bool use_color = false;
100106
console_color_t color = CONSOLE_COLOR_DEFAULT;
107+
#if !defined (_WIN32)
108+
termios prev_state;
109+
#endif
101110
};
102111

103-
void set_console_color(console_state & con_st, console_color_t color);
104-
105-
#if defined (_WIN32)
106-
void win32_console_init(bool enable_color);
107-
void win32_utf8_encode(const std::wstring & wstr, std::string & str);
108-
#endif
112+
void console_init(console_state & con_st);
113+
void console_cleanup(console_state & con_st);
114+
void console_set_color(console_state & con_st, console_color_t color);
115+
bool console_readline(console_state & con_st, std::string & line);

examples/main/main.cpp

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ static bool is_interacting = false;
3131

3232
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
3333
void sigint_handler(int signo) {
34-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
35-
printf("\n"); // this also force flush stdout.
3634
if (signo == SIGINT) {
3735
if (!is_interacting) {
3836
is_interacting=true;
3937
} else {
38+
console_cleanup(con_st);
39+
printf("\n");
4040
llama_print_timings(*g_ctx);
4141
_exit(130);
4242
}
@@ -55,10 +55,9 @@ int main(int argc, char ** argv) {
5555
// save choice to use color for later
5656
// (note for later: this is a slightly awkward choice)
5757
con_st.use_color = params.use_color;
58-
59-
#if defined (_WIN32)
60-
win32_console_init(params.use_color);
61-
#endif
58+
con_st.author_mode = params.author_mode;
59+
console_init(con_st);
60+
atexit([]() { console_cleanup(con_st); });
6261

6362
if (params.perplexity) {
6463
printf("\n************\n");
@@ -286,12 +285,21 @@ int main(int argc, char ** argv) {
286285
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
287286

288287
if (params.interactive) {
288+
const char *control_message;
289+
if (con_st.author_mode) {
290+
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
291+
" - To return control without starting a new line, end your input with '/'.\n";
292+
} else {
293+
control_message = " - Press Return to return control to LLaMa.\n"
294+
" - To return control without starting a new line, end your input with '/'.\n"
295+
" - If you want to submit another line, end your input with '\\'.\n";
296+
}
289297
fprintf(stderr, "== Running in interactive mode. ==\n"
290298
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
291299
" - Press Ctrl+C to interject at any time.\n"
292300
#endif
293-
" - Press Return to return control to LLaMa.\n"
294-
" - If you want to submit another line, end your input in '\\'.\n\n");
301+
"%s\n", control_message);
302+
295303
is_interacting = params.interactive_first;
296304
}
297305

@@ -310,7 +318,7 @@ int main(int argc, char ** argv) {
310318
int n_session_consumed = 0;
311319

312320
// the first thing we will do is to output the prompt, so set color accordingly
313-
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
321+
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
314322

315323
std::vector<llama_token> embd;
316324

@@ -508,7 +516,7 @@ int main(int argc, char ** argv) {
508516
}
509517
// reset color to default if there is no pending user input
510518
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
511-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
519+
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
512520
}
513521

514522
// in interactive mode, and not currently processing queued inputs;
@@ -528,17 +536,12 @@ int main(int argc, char ** argv) {
528536
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
529537
is_interacting = true;
530538
is_antiprompt = true;
531-
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
532-
fflush(stdout);
533539
break;
534540
}
535541
}
536542
}
537543

538544
if (n_past > 0 && is_interacting) {
539-
// potentially set color to indicate we are taking user input
540-
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
541-
542545
#if defined (_WIN32)
543546
// Windows: must reactivate sigint handler after each signal
544547
signal(SIGINT, sigint_handler);
@@ -557,29 +560,12 @@ int main(int argc, char ** argv) {
557560
std::string line;
558561
bool another_line = true;
559562
do {
560-
#if defined(_WIN32)
561-
std::wstring wline;
562-
if (!std::getline(std::wcin, wline)) {
563-
// input stream is bad or EOF received
564-
return 0;
565-
}
566-
win32_utf8_encode(wline, line);
567-
#else
568-
if (!std::getline(std::cin, line)) {
569-
// input stream is bad or EOF received
570-
return 0;
571-
}
572-
#endif
573-
if (line.empty() || line.back() != '\\') {
574-
another_line = false;
575-
} else {
576-
line.pop_back(); // Remove the continue character
577-
}
578-
buffer += line + '\n'; // Append the line to the result
563+
another_line = console_readline(con_st, line);
564+
buffer += line;
579565
} while (another_line);
580566

581567
// done taking input, reset color
582-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
568+
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
583569

584570
// Add tokens to embd only if the input buffer is non-empty
585571
// Entering a empty line lets the user pass control back
@@ -627,14 +613,8 @@ int main(int argc, char ** argv) {
627613
}
628614
}
629615

630-
#if defined (_WIN32)
631-
signal(SIGINT, SIG_DFL);
632-
#endif
633-
634616
llama_print_timings(ctx);
635617
llama_free(ctx);
636618

637-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
638-
639619
return 0;
640620
}

0 commit comments

Comments
 (0)