@@ -22,7 +22,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
22
22
const wchar_t * lpWideCharStr, int cchWideChar,
23
23
char * lpMultiByteStr, int cbMultiByte,
24
24
const char * lpDefaultChar, bool * lpUsedDefaultChar);
25
+ #define ENABLE_LINE_INPUT 0x0002
26
+ #define ENABLE_ECHO_INPUT 0x0004
25
27
#define CP_UTF8 65001
28
+ #define CONSOLE_CHAR_TYPE wchar_t
29
+ #define CONSOLE_GET_CHAR () getwchar()
30
+ #define CONSOLE_EOF WEOF
31
+ #else
32
+ #include < unistd.h>
33
+ #define CONSOLE_CHAR_TYPE char
34
+ #define CONSOLE_GET_CHAR () getchar()
35
+ #define CONSOLE_EOF EOF
26
36
#endif
27
37
28
38
bool gpt_params_parse (int argc, char ** argv, gpt_params & params) {
@@ -208,6 +218,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
208
218
params.embedding = true ;
209
219
} else if (arg == " --interactive-first" ) {
210
220
params.interactive_first = true ;
221
+ } else if (arg == " --author-mode" ) {
222
+ params.author_mode = true ;
211
223
} else if (arg == " -ins" || arg == " --instruct" ) {
212
224
params.instruct = true ;
213
225
} else if (arg == " --color" ) {
@@ -291,6 +303,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
291
303
fprintf (stderr, " -i, --interactive run in interactive mode\n " );
292
304
fprintf (stderr, " --interactive-first run in interactive mode and wait for input right away\n " );
293
305
fprintf (stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n " );
306
+ fprintf (stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\ '\n " );
294
307
fprintf (stderr, " -r PROMPT, --reverse-prompt PROMPT\n " );
295
308
fprintf (stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n " );
296
309
fprintf (stderr, " specified more than once for multiple prompts).\n " );
@@ -377,7 +390,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
377
390
}
378
391
379
392
/* Keep track of current color of output, and emit ANSI code if it changes. */
380
- void set_console_color (console_state & con_st, console_color_t color) {
393
+ void console_set_color (console_state & con_st, console_color_t color) {
381
394
if (con_st.use_color && con_st.color != color) {
382
395
switch (color) {
383
396
case CONSOLE_COLOR_DEFAULT:
@@ -394,8 +407,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
394
407
}
395
408
}
396
409
410
+ void console_init (console_state & con_st) {
397
411
#if defined (_WIN32)
398
- void win32_console_init ( bool enable_color) {
412
+ // Windows-specific console initialization
399
413
unsigned long dwMode = 0 ;
400
414
void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
401
415
if (!hConOut || hConOut == (void *)-1 || !GetConsoleMode (hConOut, &dwMode)) {
@@ -406,7 +420,7 @@ void win32_console_init(bool enable_color) {
406
420
}
407
421
if (hConOut) {
408
422
// Enable ANSI colors on Windows 10+
409
- if (enable_color && !(dwMode & 0x4 )) {
423
+ if (con_st. use_color && !(dwMode & 0x4 )) {
410
424
SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
411
425
}
412
426
// Set console output codepage to UTF8
@@ -416,9 +430,46 @@ void win32_console_init(bool enable_color) {
416
430
if (hConIn && hConIn != (void *)-1 && GetConsoleMode (hConIn, &dwMode)) {
417
431
// Set console input codepage to UTF16
418
432
_setmode (_fileno (stdin), _O_WTEXT);
433
+
434
+ // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
435
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
436
+ SetConsoleMode (hConIn, dwMode);
437
+ }
438
+ #else
439
+ // POSIX-specific console initialization
440
+ struct termios new_termios;
441
+ tcgetattr (STDIN_FILENO, &con_st.prev_state );
442
+ new_termios = con_st.prev_state ;
443
+ new_termios.c_lflag &= ~(ICANON | ECHO);
444
+ new_termios.c_cc [VMIN] = 1 ;
445
+ new_termios.c_cc [VTIME] = 0 ;
446
+ tcsetattr (STDIN_FILENO, TCSANOW, &new_termios);
447
+ #endif
448
+ }
449
+
450
+ void console_cleanup (console_state & con_st) {
451
+ #if !defined(_WIN32)
452
+ // Restore the terminal settings on POSIX systems
453
+ tcsetattr (STDIN_FILENO, TCSANOW, &con_st.prev_state );
454
+ #endif
455
+
456
+ // Reset console color
457
+ console_set_color (con_st, CONSOLE_COLOR_DEFAULT);
458
+ }
459
+
460
+ // Helper function to remove the last UTF-8 character from a string
461
+ void remove_last_utf8_char (std::string & line) {
462
+ if (line.empty ()) return ;
463
+ size_t pos = line.length () - 1 ;
464
+
465
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
466
+ for (size_t i = 0 ; i < 3 && pos > 0 ; ++i, --pos) {
467
+ if ((line[pos] & 0xC0 ) != 0x80 ) break ; // Found the start of the character
419
468
}
469
+ line.erase (pos);
420
470
}
421
471
472
+ #if defined (_WIN32)
422
473
// Convert a wide Unicode string to an UTF8 string
423
474
void win32_utf8_encode (const std::wstring & wstr, std::string & str) {
424
475
int size_needed = WideCharToMultiByte (CP_UTF8, 0 , &wstr[0 ], (int )wstr.size (), NULL , 0 , NULL , NULL );
@@ -427,3 +478,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
427
478
str = strTo;
428
479
}
429
480
#endif
481
+
482
+ bool console_readline (console_state & con_st, std::string & line) {
483
+ line.clear ();
484
+ bool is_special_char = false ;
485
+ bool end_of_stream = false ;
486
+
487
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
488
+
489
+ CONSOLE_CHAR_TYPE input_char;
490
+ while (true ) {
491
+ fflush (stdout); // Ensure all output is displayed before waiting for input
492
+ input_char = CONSOLE_GET_CHAR ();
493
+
494
+ if (input_char == ' \r ' || input_char == ' \n ' ) {
495
+ break ;
496
+ }
497
+
498
+ if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/ ) {
499
+ end_of_stream = true ;
500
+ break ;
501
+ }
502
+
503
+ if (is_special_char) {
504
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
505
+ putchar (' \b ' );
506
+ putchar (line.back ());
507
+ is_special_char = false ;
508
+ }
509
+
510
+ if (input_char == ' \033 ' ) { // Escape sequence
511
+ CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR ();
512
+ if (code == ' [' ) {
513
+ // Discard the rest of the escape sequence
514
+ while ((code = CONSOLE_GET_CHAR ()) != CONSOLE_EOF) {
515
+ if ((code >= ' A' && code <= ' Z' ) || (code >= ' a' && code <= ' z' ) || code == ' ~' ) {
516
+ break ;
517
+ }
518
+ }
519
+ }
520
+ } else if (input_char == 0x08 || input_char == 0x7F ) { // Backspace
521
+ if (!line.empty ()) {
522
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
523
+ remove_last_utf8_char (line);
524
+ }
525
+ } else if (input_char < 32 ) {
526
+ // Ignore control characters
527
+ } else {
528
+ #if defined(_WIN32)
529
+ std::string utf8_char;
530
+ win32_utf8_encode (std::wstring (1 , input_char), utf8_char);
531
+ line += utf8_char;
532
+ fputs (utf8_char.c_str (), stdout);
533
+ #else
534
+ line += input_char;
535
+ putchar (input_char);
536
+ #endif
537
+ }
538
+
539
+ if (!line.empty () && (line.back () == ' \\ ' || line.back () == ' /' )) {
540
+ console_set_color (con_st, CONSOLE_COLOR_PROMPT);
541
+ putchar (' \b ' );
542
+ putchar (line.back ());
543
+ is_special_char = true ;
544
+ }
545
+ }
546
+
547
+ bool has_more = con_st.author_mode ;
548
+ if (is_special_char) {
549
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
550
+
551
+ char last = line.back ();
552
+ line.pop_back ();
553
+ if (last == ' \\ ' ) {
554
+ line += ' \n ' ;
555
+ putchar (' \n ' );
556
+ has_more = !has_more;
557
+ } else {
558
+ // llama doesn't seem to process a single space
559
+ if (line.length () == 1 && line.back () == ' ' ) {
560
+ line.clear ();
561
+ putchar (' \b ' );
562
+ }
563
+ has_more = false ;
564
+ }
565
+ } else {
566
+ if (end_of_stream) {
567
+ has_more = false ;
568
+ } else {
569
+ line += ' \n ' ;
570
+ putchar (' \n ' );
571
+ }
572
+ }
573
+
574
+ fflush (stdout);
575
+ return has_more;
576
+ }
0 commit comments