@@ -512,6 +512,7 @@ void chat(
512
512
int prev_token;
513
513
int pos = 0 ; // position in the sequence
514
514
while (pos < steps) {
515
+
515
516
// when it is the user's turn to contribute tokens to the dialog...
516
517
if (user_turn) {
517
518
// get the (optional) system prompt at position 0
@@ -538,19 +539,21 @@ void chat(
538
539
}
539
540
// render user/system prompts into the Llama 2 Chat schema
540
541
if (pos == 0 && system_prompt[0 ] != ' \0 ' ) {
541
- const char system_template[] = " <s>[INST] <<SYS>>\n %s\n <</SYS>>\n\n %s [/INST]" ;
542
+ // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
543
+ const char system_template[] = " [INST] <<SYS>>\n %s\n <</SYS>>\n\n %s [/INST]" ;
542
544
snprintf (
543
545
rendered_prompt, RENDERED_PROMPT_SIZE-1 , system_template, system_prompt, user_prompt);
544
546
} else {
545
547
// Assistant should produce </s>, so we do not include it in template
546
- // "</s><s>[INST] %s [/INST]" for subsequent user inputs.
547
- const char user_template[] = " <s> [INST] %s [/INST]" ;
548
+ // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
549
+ const char user_template[] = " [INST] %s [/INST]" ;
548
550
snprintf (rendered_prompt, RENDERED_PROMPT_SIZE-1 , user_template, user_prompt);
549
551
}
550
552
551
553
// encode the rendered prompt into tokens
552
554
prompt_tokens = tokenizer->encode (rendered_prompt, 1 , 0 );
553
555
num_prompt_tokens = prompt_tokens.size ();
556
+
554
557
user_idx = 0 ; // reset the user index
555
558
user_turn = 0 ;
556
559
printf (" Assistant: " );
@@ -566,27 +569,27 @@ void chat(
566
569
token = next;
567
570
}
568
571
572
+ // forward the transformer to get logits for the next token
573
+ float * logits = forward (transformer, token, pos);
574
+ next = sample (sampler, logits);
575
+
576
+
569
577
if (token == EOS_TOKEN) {
570
578
user_turn = 1 ;
571
- pos++;
572
- } else {
573
- // forward the transformer to get logits for the next token
574
- float * logits = forward (transformer, token, pos);
575
- next = sample (sampler, logits);
576
- pos++;
577
-
578
- if (user_idx >= num_prompt_tokens && next != EOS_TOKEN && next != SOS_TOKEN) {
579
- // the Assistant is responding, so print its output
580
- std::string piece = tokenizer->decode (token, next);
581
- safe_printf (piece.c_str ()); // same as printf("%s", piece), but skips
582
- // "unsafe" bytes
583
- fflush (stdout);
584
- }
585
- if (next == EOS_TOKEN) {
586
- printf (" \n " );
587
- }
588
579
}
589
580
581
+ if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
582
+ std::string piece = tokenizer->decode (token, next);
583
+ safe_printf (piece.c_str ()); // same as printf("%s", piece), but skips
584
+ // "unsafe" bytes
585
+ fflush (stdout);
586
+ }
587
+
588
+ if (next == EOS_TOKEN) {
589
+ printf (" \n " );
590
+ }
591
+ pos++;
592
+
590
593
}
591
594
printf (" \n " );
592
595
}
0 commit comments