Skip to content

Commit 52e3190

Browse files
committed
Add author mode and other related QOL improvements
1 parent a90e96b commit 52e3190

File tree

3 files changed

+184
-48
lines changed

3 files changed

+184
-48
lines changed

examples/common.cpp

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
2727
const wchar_t * lpWideCharStr, int cchWideChar,
2828
char * lpMultiByteStr, int cbMultiByte,
2929
const char * lpDefaultChar, bool * lpUsedDefaultChar);
30+
#define ENABLE_LINE_INPUT 0x0002
31+
#define ENABLE_ECHO_INPUT 0x0004
3032
#define CP_UTF8 65001
33+
#define CONSOLE_CHAR_TYPE wchar_t
34+
#define CONSOLE_GET_CHAR() getwchar()
35+
#define CONSOLE_EOF WEOF
36+
#else
37+
#include <unistd.h>
38+
#define CONSOLE_CHAR_TYPE char
39+
#define CONSOLE_GET_CHAR() getchar()
40+
#define CONSOLE_EOF EOF
3141
#endif
3242

3343
int32_t get_num_physical_cores() {
@@ -264,6 +274,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
264274
params.embedding = true;
265275
} else if (arg == "--interactive-first") {
266276
params.interactive_first = true;
277+
} else if (arg == "--author-mode") {
278+
params.author_mode = true;
267279
} else if (arg == "-ins" || arg == "--instruct") {
268280
params.instruct = true;
269281
} else if (arg == "--color") {
@@ -356,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
356368
fprintf(stderr, " -i, --interactive run in interactive mode\n");
357369
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
358370
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
371+
fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n");
359372
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
360373
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
361374
fprintf(stderr, " specified more than once for multiple prompts).\n");
@@ -477,7 +490,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
477490
}
478491

479492
/* Keep track of current color of output, and emit ANSI code if it changes. */
480-
void set_console_color(console_state & con_st, console_color_t color) {
493+
void console_set_color(console_state & con_st, console_color_t color) {
481494
if (con_st.use_color && con_st.color != color) {
482495
switch(color) {
483496
case CONSOLE_COLOR_DEFAULT:
@@ -494,8 +507,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
494507
}
495508
}
496509

510+
void console_init(console_state & con_st) {
497511
#if defined (_WIN32)
498-
void win32_console_init(bool enable_color) {
512+
// Windows-specific console initialization
499513
unsigned long dwMode = 0;
500514
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
501515
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
@@ -506,7 +520,7 @@ void win32_console_init(bool enable_color) {
506520
}
507521
if (hConOut) {
508522
// Enable ANSI colors on Windows 10+
509-
if (enable_color && !(dwMode & 0x4)) {
523+
if (con_st.use_color && !(dwMode & 0x4)) {
510524
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
511525
}
512526
// Set console output codepage to UTF8
@@ -516,9 +530,46 @@ void win32_console_init(bool enable_color) {
516530
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
517531
// Set console input codepage to UTF16
518532
_setmode(_fileno(stdin), _O_WTEXT);
533+
534+
// Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
535+
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
536+
SetConsoleMode(hConIn, dwMode);
537+
}
538+
#else
539+
// POSIX-specific console initialization
540+
struct termios new_termios;
541+
tcgetattr(STDIN_FILENO, &con_st.prev_state);
542+
new_termios = con_st.prev_state;
543+
new_termios.c_lflag &= ~(ICANON | ECHO);
544+
new_termios.c_cc[VMIN] = 1;
545+
new_termios.c_cc[VTIME] = 0;
546+
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
547+
#endif
548+
}
549+
550+
void console_cleanup(console_state & con_st) {
551+
#if !defined(_WIN32)
552+
// Restore the terminal settings on POSIX systems
553+
tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
554+
#endif
555+
556+
// Reset console color
557+
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
558+
}
559+
560+
// Helper function to remove the last UTF-8 character from a string
561+
void remove_last_utf8_char(std::string & line) {
562+
if (line.empty()) return;
563+
size_t pos = line.length() - 1;
564+
565+
// Find the start of the last UTF-8 character (checking up to 4 bytes back)
566+
for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
567+
if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character
519568
}
569+
line.erase(pos);
520570
}
521571

572+
#if defined (_WIN32)
522573
// Convert a wide Unicode string to an UTF8 string
523574
void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
524575
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL);
@@ -527,3 +578,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
527578
str = strTo;
528579
}
529580
#endif
581+
582+
bool console_readline(console_state & con_st, std::string & line) {
583+
line.clear();
584+
bool is_special_char = false;
585+
bool end_of_stream = false;
586+
587+
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
588+
589+
CONSOLE_CHAR_TYPE input_char;
590+
while (true) {
591+
fflush(stdout); // Ensure all output is displayed before waiting for input
592+
input_char = CONSOLE_GET_CHAR();
593+
594+
if (input_char == '\r' || input_char == '\n') {
595+
break;
596+
}
597+
598+
if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) {
599+
end_of_stream = true;
600+
break;
601+
}
602+
603+
if (is_special_char) {
604+
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
605+
putchar('\b');
606+
putchar(line.back());
607+
is_special_char = false;
608+
}
609+
610+
if (input_char == '\033') { // Escape sequence
611+
CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR();
612+
if (code == '[') {
613+
// Discard the rest of the escape sequence
614+
while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) {
615+
if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
616+
break;
617+
}
618+
}
619+
}
620+
} else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
621+
if (!line.empty()) {
622+
fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
623+
remove_last_utf8_char(line);
624+
}
625+
} else if (input_char < 32) {
626+
// Ignore control characters
627+
} else {
628+
#if defined(_WIN32)
629+
std::string utf8_char;
630+
win32_utf8_encode(std::wstring(1, input_char), utf8_char);
631+
line += utf8_char;
632+
fputs(utf8_char.c_str(), stdout);
633+
#else
634+
line += input_char;
635+
putchar(input_char);
636+
#endif
637+
}
638+
639+
if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
640+
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
641+
putchar('\b');
642+
putchar(line.back());
643+
is_special_char = true;
644+
}
645+
}
646+
647+
bool has_more = con_st.author_mode;
648+
if (is_special_char) {
649+
fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
650+
651+
char last = line.back();
652+
line.pop_back();
653+
if (last == '\\') {
654+
line += '\n';
655+
putchar('\n');
656+
has_more = !has_more;
657+
} else {
658+
// llama doesn't seem to process a single space
659+
if (line.length() == 1 && line.back() == ' ') {
660+
line.clear();
661+
putchar('\b');
662+
}
663+
has_more = false;
664+
}
665+
} else {
666+
if (end_of_stream) {
667+
has_more = false;
668+
} else {
669+
line += '\n';
670+
putchar('\n');
671+
}
672+
}
673+
674+
fflush(stdout);
675+
return has_more;
676+
}

examples/common.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#include <thread>
1111
#include <unordered_map>
1212

13+
#if !defined (_WIN32)
14+
#include <termios.h>
15+
#endif
16+
1317
//
1418
// CLI argument parsing
1519
//
@@ -56,6 +60,7 @@ struct gpt_params {
5660

5761
bool embedding = false; // get only sentence embedding
5862
bool interactive_first = false; // wait for user input immediately
63+
bool author_mode = false; // reverse the usage of `\`
5964

6065
bool instruct = false; // instruction mode (used for Alpaca models)
6166
bool penalize_nl = true; // consider newlines as a repeatable token
@@ -104,13 +109,15 @@ enum console_color_t {
104109
};
105110

106111
struct console_state {
112+
bool author_mode = false;
107113
bool use_color = false;
108114
console_color_t color = CONSOLE_COLOR_DEFAULT;
115+
#if !defined (_WIN32)
116+
termios prev_state;
117+
#endif
109118
};
110119

111-
void set_console_color(console_state & con_st, console_color_t color);
112-
113-
#if defined (_WIN32)
114-
void win32_console_init(bool enable_color);
115-
void win32_utf8_encode(const std::wstring & wstr, std::string & str);
116-
#endif
120+
void console_init(console_state & con_st);
121+
void console_cleanup(console_state & con_st);
122+
void console_set_color(console_state & con_st, console_color_t color);
123+
bool console_readline(console_state & con_st, std::string & line);

examples/main/main.cpp

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ static bool is_interacting = false;
3535

3636
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
3737
void sigint_handler(int signo) {
38-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
39-
printf("\n"); // this also force flush stdout.
4038
if (signo == SIGINT) {
4139
if (!is_interacting) {
4240
is_interacting=true;
4341
} else {
42+
console_cleanup(con_st);
43+
printf("\n");
4444
llama_print_timings(*g_ctx);
4545
_exit(130);
4646
}
@@ -59,10 +59,9 @@ int main(int argc, char ** argv) {
5959
// save choice to use color for later
6060
// (note for later: this is a slightly awkward choice)
6161
con_st.use_color = params.use_color;
62-
63-
#if defined (_WIN32)
64-
win32_console_init(params.use_color);
65-
#endif
62+
con_st.author_mode = params.author_mode;
63+
console_init(con_st);
64+
atexit([]() { console_cleanup(con_st); });
6665

6766
if (params.perplexity) {
6867
printf("\n************\n");
@@ -275,12 +274,21 @@ int main(int argc, char ** argv) {
275274
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
276275

277276
if (params.interactive) {
277+
const char *control_message;
278+
if (con_st.author_mode) {
279+
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
280+
" - To return control without starting a new line, end your input with '/'.\n";
281+
} else {
282+
control_message = " - Press Return to return control to LLaMa.\n"
283+
" - To return control without starting a new line, end your input with '/'.\n"
284+
" - If you want to submit another line, end your input with '\\'.\n";
285+
}
278286
fprintf(stderr, "== Running in interactive mode. ==\n"
279287
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
280288
" - Press Ctrl+C to interject at any time.\n"
281289
#endif
282-
" - Press Return to return control to LLaMa.\n"
283-
" - If you want to submit another line, end your input in '\\'.\n\n");
290+
"%s\n", control_message);
291+
284292
is_interacting = params.interactive_first;
285293
}
286294

@@ -299,7 +307,7 @@ int main(int argc, char ** argv) {
299307
int n_session_consumed = 0;
300308

301309
// the first thing we will do is to output the prompt, so set color accordingly
302-
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
310+
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
303311

304312
std::vector<llama_token> embd;
305313

@@ -498,7 +506,7 @@ int main(int argc, char ** argv) {
498506
}
499507
// reset color to default if we there is no pending user input
500508
if (input_echo && (int)embd_inp.size() == n_consumed) {
501-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
509+
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
502510
}
503511

504512
// in interactive mode, and not currently processing queued inputs;
@@ -518,17 +526,12 @@ int main(int argc, char ** argv) {
518526
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
519527
is_interacting = true;
520528
is_antiprompt = true;
521-
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
522-
fflush(stdout);
523529
break;
524530
}
525531
}
526532
}
527533

528534
if (n_past > 0 && is_interacting) {
529-
// potentially set color to indicate we are taking user input
530-
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
531-
532535
if (params.instruct) {
533536
printf("\n> ");
534537
}
@@ -542,31 +545,12 @@ int main(int argc, char ** argv) {
542545
std::string line;
543546
bool another_line = true;
544547
do {
545-
#if defined(_WIN32)
546-
std::wstring wline;
547-
if (!std::getline(std::wcin, wline)) {
548-
// input stream is bad or EOF received
549-
return 0;
550-
}
551-
win32_utf8_encode(wline, line);
552-
#else
553-
if (!std::getline(std::cin, line)) {
554-
// input stream is bad or EOF received
555-
return 0;
556-
}
557-
#endif
558-
if (!line.empty()) {
559-
if (line.back() == '\\') {
560-
line.pop_back(); // Remove the continue character
561-
} else {
562-
another_line = false;
563-
}
564-
buffer += line + '\n'; // Append the line to the result
565-
}
548+
another_line = console_readline(con_st, line);
549+
buffer += line;
566550
} while (another_line);
567551

568552
// done taking input, reset color
569-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
553+
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
570554

571555
// Add tokens to embd only if the input buffer is non-empty
572556
// Entering a empty line lets the user pass control back
@@ -622,7 +606,5 @@ int main(int argc, char ** argv) {
622606
llama_print_timings(ctx);
623607
llama_free(ctx);
624608

625-
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
626-
627609
return 0;
628610
}

0 commit comments

Comments
 (0)