@@ -20,7 +20,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
20
20
const wchar_t * lpWideCharStr, int cchWideChar,
21
21
char * lpMultiByteStr, int cbMultiByte,
22
22
const char * lpDefaultChar, bool * lpUsedDefaultChar);
23
+ #define ENABLE_LINE_INPUT 0x0002
24
+ #define ENABLE_ECHO_INPUT 0x0004
23
25
#define CP_UTF8 65001
26
+ #define CONSOLE_CHAR_TYPE wchar_t
27
+ #define CONSOLE_GET_CHAR () getwchar()
28
+ #define CONSOLE_EOF WEOF
29
+ #else
30
+ #include < unistd.h>
31
+ #define CONSOLE_CHAR_TYPE char
32
+ #define CONSOLE_GET_CHAR () getchar()
33
+ #define CONSOLE_EOF EOF
24
34
#endif
25
35
26
36
bool gpt_params_parse (int argc, char ** argv, gpt_params & params) {
@@ -158,6 +168,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
158
168
params.embedding = true ;
159
169
} else if (arg == " --interactive-first" ) {
160
170
params.interactive_first = true ;
171
+ } else if (arg == " --author-mode" ) {
172
+ params.author_mode = true ;
161
173
} else if (arg == " -ins" || arg == " --instruct" ) {
162
174
params.instruct = true ;
163
175
} else if (arg == " --color" ) {
@@ -220,6 +232,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
220
232
fprintf (stderr, " -i, --interactive run in interactive mode\n " );
221
233
fprintf (stderr, " --interactive-first run in interactive mode and wait for input right away\n " );
222
234
fprintf (stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n " );
235
+ fprintf (stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\ '\n " );
223
236
fprintf (stderr, " -r PROMPT, --reverse-prompt PROMPT\n " );
224
237
fprintf (stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n " );
225
238
fprintf (stderr, " specified more than once for multiple prompts).\n " );
@@ -291,7 +304,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
291
304
}
292
305
293
306
/* Keep track of current color of output, and emit ANSI code if it changes. */
294
- void set_console_color (console_state & con_st, console_color_t color) {
307
+ void console_set_color (console_state & con_st, console_color_t color) {
295
308
if (con_st.use_color && con_st.color != color) {
296
309
switch (color) {
297
310
case CONSOLE_COLOR_DEFAULT:
@@ -308,8 +321,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
308
321
}
309
322
}
310
323
324
+ void console_init (console_state & con_st) {
311
325
#if defined (_WIN32)
312
- void win32_console_init ( bool enable_color) {
326
+ // Windows-specific console initialization
313
327
unsigned long dwMode = 0 ;
314
328
void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
315
329
if (!hConOut || hConOut == (void *)-1 || !GetConsoleMode (hConOut, &dwMode)) {
@@ -320,7 +334,7 @@ void win32_console_init(bool enable_color) {
320
334
}
321
335
if (hConOut) {
322
336
// Enable ANSI colors on Windows 10+
323
- if (enable_color && !(dwMode & 0x4 )) {
337
+ if (con_st. use_color && !(dwMode & 0x4 )) {
324
338
SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
325
339
}
326
340
// Set console output codepage to UTF8
@@ -330,9 +344,46 @@ void win32_console_init(bool enable_color) {
330
344
if (hConIn && hConIn != (void *)-1 && GetConsoleMode (hConIn, &dwMode)) {
331
345
// Set console input codepage to UTF16
332
346
_setmode (_fileno (stdin), _O_WTEXT);
347
+
348
+ // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
349
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
350
+ SetConsoleMode (hConIn, dwMode);
351
+ }
352
+ #else
353
+ // POSIX-specific console initialization
354
+ struct termios new_termios;
355
+ tcgetattr (STDIN_FILENO, &con_st.prev_state );
356
+ new_termios = con_st.prev_state ;
357
+ new_termios.c_lflag &= ~(ICANON | ECHO);
358
+ new_termios.c_cc [VMIN] = 1 ;
359
+ new_termios.c_cc [VTIME] = 0 ;
360
+ tcsetattr (STDIN_FILENO, TCSANOW, &new_termios);
361
+ #endif
362
+ }
363
+
364
+ void console_cleanup (console_state & con_st) {
365
+ #if !defined(_WIN32)
366
+ // Restore the terminal settings on POSIX systems
367
+ tcsetattr (STDIN_FILENO, TCSANOW, &con_st.prev_state );
368
+ #endif
369
+
370
+ // Reset console color
371
+ console_set_color (con_st, CONSOLE_COLOR_DEFAULT);
372
+ }
373
+
374
+ // Helper function to remove the last UTF-8 character from a string
375
+ void remove_last_utf8_char (std::string & line) {
376
+ if (line.empty ()) return ;
377
+ size_t pos = line.length () - 1 ;
378
+
379
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
380
+ for (size_t i = 0 ; i < 3 && pos > 0 ; ++i, --pos) {
381
+ if ((line[pos] & 0xC0 ) != 0x80 ) break ; // Found the start of the character
333
382
}
383
+ line.erase (pos);
334
384
}
335
385
386
+ #if defined (_WIN32)
336
387
// Convert a wide Unicode string to an UTF8 string
337
388
void win32_utf8_encode (const std::wstring & wstr, std::string & str) {
338
389
int size_needed = WideCharToMultiByte (CP_UTF8, 0 , &wstr[0 ], (int )wstr.size (), NULL , 0 , NULL , NULL );
@@ -341,3 +392,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
341
392
str = strTo;
342
393
}
343
394
#endif
395
+
396
+ bool console_readline (console_state & con_st, std::string & line) {
397
+ line.clear ();
398
+ bool is_special_char = false ;
399
+ bool end_of_stream = false ;
400
+
401
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
402
+
403
+ CONSOLE_CHAR_TYPE input_char;
404
+ while (true ) {
405
+ fflush (stdout); // Ensure all output is displayed before waiting for input
406
+ input_char = CONSOLE_GET_CHAR ();
407
+
408
+ if (input_char == ' \r ' || input_char == ' \n ' ) {
409
+ break ;
410
+ }
411
+
412
+ if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/ ) {
413
+ end_of_stream = true ;
414
+ break ;
415
+ }
416
+
417
+ if (is_special_char) {
418
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
419
+ putchar (' \b ' );
420
+ putchar (line.back ());
421
+ is_special_char = false ;
422
+ }
423
+
424
+ if (input_char == ' \033 ' ) { // Escape sequence
425
+ CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR ();
426
+ if (code == ' [' ) {
427
+ // Discard the rest of the escape sequence
428
+ while ((code = CONSOLE_GET_CHAR ()) != CONSOLE_EOF) {
429
+ if ((code >= ' A' && code <= ' Z' ) || (code >= ' a' && code <= ' z' ) || code == ' ~' ) {
430
+ break ;
431
+ }
432
+ }
433
+ }
434
+ } else if (input_char == 0x08 || input_char == 0x7F ) { // Backspace
435
+ if (!line.empty ()) {
436
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
437
+ remove_last_utf8_char (line);
438
+ }
439
+ } else if (input_char < 32 ) {
440
+ // Ignore control characters
441
+ } else {
442
+ #if defined(_WIN32)
443
+ std::string utf8_char;
444
+ win32_utf8_encode (std::wstring (1 , input_char), utf8_char);
445
+ line += utf8_char;
446
+ fputs (utf8_char.c_str (), stdout);
447
+ #else
448
+ line += input_char;
449
+ putchar (input_char);
450
+ #endif
451
+ }
452
+
453
+ if (!line.empty () && (line.back () == ' \\ ' || line.back () == ' /' )) {
454
+ console_set_color (con_st, CONSOLE_COLOR_PROMPT);
455
+ putchar (' \b ' );
456
+ putchar (line.back ());
457
+ is_special_char = true ;
458
+ }
459
+ }
460
+
461
+ bool has_more = con_st.author_mode ;
462
+ if (is_special_char) {
463
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
464
+
465
+ char last = line.back ();
466
+ line.pop_back ();
467
+ if (last == ' \\ ' ) {
468
+ line += ' \n ' ;
469
+ putchar (' \n ' );
470
+ has_more = !has_more;
471
+ } else {
472
+ // llama doesn't seem to process a single space
473
+ if (line.length () == 1 && line.back () == ' ' ) {
474
+ line.clear ();
475
+ putchar (' \b ' );
476
+ }
477
+ has_more = false ;
478
+ }
479
+ } else {
480
+ if (end_of_stream) {
481
+ has_more = false ;
482
+ } else {
483
+ line += ' \n ' ;
484
+ putchar (' \n ' );
485
+ }
486
+ }
487
+
488
+ fflush (stdout);
489
+ return has_more;
490
+ }
0 commit comments