@@ -27,7 +27,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
27
27
const wchar_t * lpWideCharStr, int cchWideChar,
28
28
char * lpMultiByteStr, int cbMultiByte,
29
29
const char * lpDefaultChar, bool * lpUsedDefaultChar);
30
+ #define ENABLE_LINE_INPUT 0x0002
31
+ #define ENABLE_ECHO_INPUT 0x0004
30
32
#define CP_UTF8 65001
33
+ #define CONSOLE_CHAR_TYPE wchar_t
34
+ #define CONSOLE_GET_CHAR () getwchar()
35
+ #define CONSOLE_EOF WEOF
36
+ #else
37
+ #include < unistd.h>
38
+ #define CONSOLE_CHAR_TYPE char
39
+ #define CONSOLE_GET_CHAR () getchar()
40
+ #define CONSOLE_EOF EOF
31
41
#endif
32
42
33
43
int32_t get_num_physical_cores () {
@@ -264,6 +274,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
264
274
params.embedding = true ;
265
275
} else if (arg == " --interactive-first" ) {
266
276
params.interactive_first = true ;
277
+ } else if (arg == " --author-mode" ) {
278
+ params.author_mode = true ;
267
279
} else if (arg == " -ins" || arg == " --instruct" ) {
268
280
params.instruct = true ;
269
281
} else if (arg == " --color" ) {
@@ -356,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
356
368
fprintf (stderr, " -i, --interactive run in interactive mode\n " );
357
369
fprintf (stderr, " --interactive-first run in interactive mode and wait for input right away\n " );
358
370
fprintf (stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n " );
371
+ fprintf (stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\ '\n " );
359
372
fprintf (stderr, " -r PROMPT, --reverse-prompt PROMPT\n " );
360
373
fprintf (stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n " );
361
374
fprintf (stderr, " specified more than once for multiple prompts).\n " );
@@ -477,7 +490,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
477
490
}
478
491
479
492
/* Keep track of current color of output, and emit ANSI code if it changes. */
480
- void set_console_color (console_state & con_st, console_color_t color) {
493
+ void console_set_color (console_state & con_st, console_color_t color) {
481
494
if (con_st.use_color && con_st.color != color) {
482
495
switch (color) {
483
496
case CONSOLE_COLOR_DEFAULT:
@@ -494,8 +507,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
494
507
}
495
508
}
496
509
510
+ void console_init (console_state & con_st) {
497
511
#if defined (_WIN32)
498
- void win32_console_init ( bool enable_color) {
512
+ // Windows-specific console initialization
499
513
unsigned long dwMode = 0 ;
500
514
void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
501
515
if (!hConOut || hConOut == (void *)-1 || !GetConsoleMode (hConOut, &dwMode)) {
@@ -506,7 +520,7 @@ void win32_console_init(bool enable_color) {
506
520
}
507
521
if (hConOut) {
508
522
// Enable ANSI colors on Windows 10+
509
- if (enable_color && !(dwMode & 0x4 )) {
523
+ if (con_st. use_color && !(dwMode & 0x4 )) {
510
524
SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
511
525
}
512
526
// Set console output codepage to UTF8
@@ -516,9 +530,46 @@ void win32_console_init(bool enable_color) {
516
530
if (hConIn && hConIn != (void *)-1 && GetConsoleMode (hConIn, &dwMode)) {
517
531
// Set console input codepage to UTF16
518
532
_setmode (_fileno (stdin), _O_WTEXT);
533
+
534
+ // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
535
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
536
+ SetConsoleMode (hConIn, dwMode);
537
+ }
538
+ #else
539
+ // POSIX-specific console initialization
540
+ struct termios new_termios;
541
+ tcgetattr (STDIN_FILENO, &con_st.prev_state );
542
+ new_termios = con_st.prev_state ;
543
+ new_termios.c_lflag &= ~(ICANON | ECHO);
544
+ new_termios.c_cc [VMIN] = 1 ;
545
+ new_termios.c_cc [VTIME] = 0 ;
546
+ tcsetattr (STDIN_FILENO, TCSANOW, &new_termios);
547
+ #endif
548
+ }
549
+
550
+ void console_cleanup (console_state & con_st) {
551
+ #if !defined(_WIN32)
552
+ // Restore the terminal settings on POSIX systems
553
+ tcsetattr (STDIN_FILENO, TCSANOW, &con_st.prev_state );
554
+ #endif
555
+
556
+ // Reset console color
557
+ console_set_color (con_st, CONSOLE_COLOR_DEFAULT);
558
+ }
559
+
560
+ // Helper function to remove the last UTF-8 character from a string
561
+ void remove_last_utf8_char (std::string & line) {
562
+ if (line.empty ()) return ;
563
+ size_t pos = line.length () - 1 ;
564
+
565
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
566
+ for (size_t i = 0 ; i < 3 && pos > 0 ; ++i, --pos) {
567
+ if ((line[pos] & 0xC0 ) != 0x80 ) break ; // Found the start of the character
519
568
}
569
+ line.erase (pos);
520
570
}
521
571
572
+ #if defined (_WIN32)
522
573
// Convert a wide Unicode string to an UTF8 string
523
574
void win32_utf8_encode (const std::wstring & wstr, std::string & str) {
524
575
int size_needed = WideCharToMultiByte (CP_UTF8, 0 , &wstr[0 ], (int )wstr.size (), NULL , 0 , NULL , NULL );
@@ -527,3 +578,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
527
578
str = strTo;
528
579
}
529
580
#endif
581
+
582
+ bool console_readline (console_state & con_st, std::string & line) {
583
+ line.clear ();
584
+ bool is_special_char = false ;
585
+ bool end_of_stream = false ;
586
+
587
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
588
+
589
+ CONSOLE_CHAR_TYPE input_char;
590
+ while (true ) {
591
+ fflush (stdout); // Ensure all output is displayed before waiting for input
592
+ input_char = CONSOLE_GET_CHAR ();
593
+
594
+ if (input_char == ' \r ' || input_char == ' \n ' ) {
595
+ break ;
596
+ }
597
+
598
+ if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/ ) {
599
+ end_of_stream = true ;
600
+ break ;
601
+ }
602
+
603
+ if (is_special_char) {
604
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
605
+ putchar (' \b ' );
606
+ putchar (line.back ());
607
+ is_special_char = false ;
608
+ }
609
+
610
+ if (input_char == ' \033 ' ) { // Escape sequence
611
+ CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR ();
612
+ if (code == ' [' ) {
613
+ // Discard the rest of the escape sequence
614
+ while ((code = CONSOLE_GET_CHAR ()) != CONSOLE_EOF) {
615
+ if ((code >= ' A' && code <= ' Z' ) || (code >= ' a' && code <= ' z' ) || code == ' ~' ) {
616
+ break ;
617
+ }
618
+ }
619
+ }
620
+ } else if (input_char == 0x08 || input_char == 0x7F ) { // Backspace
621
+ if (!line.empty ()) {
622
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
623
+ remove_last_utf8_char (line);
624
+ }
625
+ } else if (input_char < 32 ) {
626
+ // Ignore control characters
627
+ } else {
628
+ #if defined(_WIN32)
629
+ std::string utf8_char;
630
+ win32_utf8_encode (std::wstring (1 , input_char), utf8_char);
631
+ line += utf8_char;
632
+ fputs (utf8_char.c_str (), stdout);
633
+ #else
634
+ line += input_char;
635
+ putchar (input_char);
636
+ #endif
637
+ }
638
+
639
+ if (!line.empty () && (line.back () == ' \\ ' || line.back () == ' /' )) {
640
+ console_set_color (con_st, CONSOLE_COLOR_PROMPT);
641
+ putchar (' \b ' );
642
+ putchar (line.back ());
643
+ is_special_char = true ;
644
+ }
645
+ }
646
+
647
+ bool has_more = con_st.author_mode ;
648
+ if (is_special_char) {
649
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
650
+
651
+ char last = line.back ();
652
+ line.pop_back ();
653
+ if (last == ' \\ ' ) {
654
+ line += ' \n ' ;
655
+ putchar (' \n ' );
656
+ has_more = !has_more;
657
+ } else {
658
+ // llama doesn't seem to process a single space
659
+ if (line.length () == 1 && line.back () == ' ' ) {
660
+ line.clear ();
661
+ putchar (' \b ' );
662
+ }
663
+ has_more = false ;
664
+ }
665
+ } else {
666
+ if (end_of_stream) {
667
+ has_more = false ;
668
+ } else {
669
+ line += ' \n ' ;
670
+ putchar (' \n ' );
671
+ }
672
+ }
673
+
674
+ fflush (stdout);
675
+ return has_more;
676
+ }
0 commit comments