Skip to content

Fix color codes emitting mid-UTF8 code. #312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,36 @@ extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHand
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"

/* Keep track of current color of output, and emit ANSI code if it changes. */
enum console_state {
CONSOLE_STATE_DEFAULT=0,
CONSOLE_STATE_PROMPT,
CONSOLE_STATE_USER_INPUT
};

static console_state con_st = CONSOLE_STATE_DEFAULT;
static bool con_use_color = false;

void set_console_state(console_state new_st)
{
if (!con_use_color) return;
// only emit color code if state changed
if (new_st != con_st) {
con_st = new_st;
switch(con_st) {
case CONSOLE_STATE_DEFAULT:
printf(ANSI_COLOR_RESET);
return;
case CONSOLE_STATE_PROMPT:
printf(ANSI_COLOR_YELLOW);
return;
case CONSOLE_STATE_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
return;
}
}
}

static const int EOS_TOKEN_ID = 2;

// determine number of model parts based on the dimension
Expand Down Expand Up @@ -866,7 +896,7 @@ static bool is_interacting = false;

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
printf(ANSI_COLOR_RESET);
set_console_state(CONSOLE_STATE_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
Expand Down Expand Up @@ -925,6 +955,10 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng);
}

// save choice to use color for later
// (note for later: this is a slightly awkward choice)
con_use_color = params.use_color;

// params.prompt = R"(// this function checks if the number n is prime
//bool is_prime(int n) {)";

Expand Down Expand Up @@ -1040,18 +1074,18 @@ int main(int argc, char ** argv) {

int remaining_tokens = params.n_predict;

// set the color for the prompt which will be output initially
if (params.use_color) {
#if defined (_WIN32)
if (params.use_color) {
// Enable ANSI colors on Windows 10+
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (hConOut && hConOut != (void*)-1 && GetConsoleMode(hConOut, &dwMode) && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
}
#endif
printf(ANSI_COLOR_YELLOW);
}
#endif
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);

while (remaining_tokens > 0 || params.interactive) {
// predict
Expand Down Expand Up @@ -1125,8 +1159,8 @@ int main(int argc, char ** argv) {
fflush(stdout);
}
// reset color to default if we there is no pending user input
if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
printf(ANSI_COLOR_RESET);
if (!input_noecho && (int)embd_inp.size() == input_consumed) {
set_console_state(CONSOLE_STATE_DEFAULT);
}

// in interactive mode, and not currently processing queued inputs;
Expand All @@ -1146,15 +1180,16 @@ int main(int argc, char ** argv) {
}
}
if (is_interacting) {
// potentially set color to indicate we are taking user input
set_console_state(CONSOLE_STATE_USER_INPUT);

if (params.instruct) {
input_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());

printf("\n> ");
}

// currently being interactive
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
std::string buffer;
std::string line;
bool another_line = true;
Expand All @@ -1167,7 +1202,9 @@ int main(int argc, char ** argv) {
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);
if (params.use_color) printf(ANSI_COLOR_RESET);

// done taking input, reset color
set_console_state(CONSOLE_STATE_DEFAULT);

std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
Expand Down Expand Up @@ -1218,9 +1255,7 @@ int main(int argc, char ** argv) {

ggml_free(model.ctx);

if (params.use_color) {
printf(ANSI_COLOR_RESET);
}
set_console_state(CONSOLE_STATE_DEFAULT);

return 0;
}