Skip to content

Commit 17ba387

Browse files
metascroy authored and malfet committed
fix run chat (#482)
1 parent 16fea12 commit 17ba387

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

runner/run.cpp

Lines changed: 23 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -512,6 +512,7 @@ void chat(
512512
int prev_token;
513513
int pos = 0; // position in the sequence
514514
while (pos < steps) {
515+
515516
// when it is the user's turn to contribute tokens to the dialog...
516517
if (user_turn) {
517518
// get the (optional) system prompt at position 0
@@ -538,19 +539,21 @@ void chat(
538539
}
539540
// render user/system prompts into the Llama 2 Chat schema
540541
if (pos == 0 && system_prompt[0] != '\0') {
541-
const char system_template[] = "<s>[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
542+
// We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
543+
const char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
542544
snprintf(
543545
rendered_prompt, RENDERED_PROMPT_SIZE-1, system_template, system_prompt, user_prompt);
544546
} else {
545547
// Assistant should produce </s>, so we do not include it in template
546-
// "</s><s>[INST] %s [/INST]" for subsequent user inputs.
547-
const char user_template[] = "<s>[INST] %s [/INST]";
548+
// We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
549+
const char user_template[] = "[INST] %s [/INST]";
548550
snprintf(rendered_prompt, RENDERED_PROMPT_SIZE-1, user_template, user_prompt);
549551
}
550552

551553
// encode the rendered prompt into tokens
552554
prompt_tokens = tokenizer->encode(rendered_prompt, 1, 0);
553555
num_prompt_tokens = prompt_tokens.size();
556+
554557
user_idx = 0; // reset the user index
555558
user_turn = 0;
556559
printf("Assistant: ");
@@ -566,27 +569,27 @@ void chat(
566569
token = next;
567570
}
568571

572+
// forward the transformer to get logits for the next token
573+
float* logits = forward(transformer, token, pos);
574+
next = sample(sampler, logits);
575+
576+
569577
if (token == EOS_TOKEN) {
570578
user_turn = 1;
571-
pos++;
572-
} else {
573-
// forward the transformer to get logits for the next token
574-
float* logits = forward(transformer, token, pos);
575-
next = sample(sampler, logits);
576-
pos++;
577-
578-
if (user_idx >= num_prompt_tokens && next != EOS_TOKEN && next != SOS_TOKEN) {
579-
// the Assistant is responding, so print its output
580-
std::string piece = tokenizer->decode(token, next);
581-
safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
582-
// "unsafe" bytes
583-
fflush(stdout);
584-
}
585-
if (next == EOS_TOKEN) {
586-
printf("\n");
587-
}
588579
}
589580

581+
if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
582+
std::string piece = tokenizer->decode(token, next);
583+
safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
584+
// "unsafe" bytes
585+
fflush(stdout);
586+
}
587+
588+
if (next == EOS_TOKEN) {
589+
printf("\n");
590+
}
591+
pos++;
592+
590593
}
591594
printf("\n");
592595
}

0 commit comments

Comments
 (0)