
Commit e761b0c

metascroy authored and malfet committed
Support llama3 in chat in run.cpp (#486)
* refactor chat runner in preparation for llama3
* add sketch for llama3 prompt template and move to returning tokens
* fix tiktoken
* fixes to chat
* add default llama_ver
1 parent 4b8fe1a commit e761b0c
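
The heart of the change is that chat prompts are now rendered per model family before being encoded: Llama 2 keeps the [INST] ... [/INST] schema, while Llama 3 spells out the <|start_header_id|>/<|eot_id|> special tokens in the template and is encoded without an extra BOS. Below is a minimal, self-contained sketch of that rendering step; the standalone main, the sample prompts, and the buffer size are illustrative only, and the real logic lives in get_initial_prompt_tokens in runner/run.cpp further down.

// Illustrative sketch: how the initial chat prompt is rendered for each model
// family before it is handed to the tokenizer. Not the literal runner code.
#include <stdio.h>

int main() {
  const char* system_prompt = "You are a helpful assistant.";
  const char* user_prompt = "What is the capital of France?";
  char rendered_prompt[512 * 2 + 200];

  // Llama 2 chat schema; in the real runner, BOS is added later by
  // tokenizer->encode(rendered_prompt, 1, 0), not written into the template.
  snprintf(
      rendered_prompt,
      sizeof(rendered_prompt) - 1,
      "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
      system_prompt,
      user_prompt);
  printf("llama2 prompt:\n%s\n\n", rendered_prompt);

  // Llama 3 chat schema; the special tokens are part of the template and are
  // pattern-matched by the Tiktoken tokenizer, so encode() is called without BOS.
  snprintf(
      rendered_prompt,
      sizeof(rendered_prompt) - 1,
      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>"
      "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>"
      "<|start_header_id|>assistant<|end_header_id|>\n\n",
      system_prompt,
      user_prompt);
  printf("llama3 prompt:\n%s\n", rendered_prompt);
  return 0;
}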

File tree: 2 files changed (+219 -67 lines)

runner/run.cpp

Lines changed: 213 additions & 65 deletions
@@ -1,12 +1,17 @@
 /* Inference for Llama-2 Transformer model in pure C++ */
+#include <cstdint>
+#include <cstdlib>
 #include <ctype.h>
+#include <iterator>
 #include <math.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 #include <time.h>
 #include <tokenizer.h>
+#include <string>
+
 
 #ifdef DEBUG
 #include <cassert>
@@ -485,27 +490,184 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
 // python reference and that seemed ok, but this was not thoroughly tested and
 // is not safely implemented, it's more a proof of concept atm.
 
+enum class ModelType {
+  unknown,
+  llama2,
+  llama3,
+};
+
+ModelType get_model_type(Tokenizer* tokenizer) {
+  if (BPETokenizer* t = dynamic_cast<BPETokenizer*>(tokenizer)) {
+    return ModelType::llama2;
+  } else if (Tiktoken* t = dynamic_cast<Tiktoken*>(tokenizer)) {
+    return ModelType::llama3;
+  } else {
+    return ModelType::unknown;
+  }
+}
+
+uint64_t get_eot_token(Tokenizer* tokenizer) {
+  ModelType model_type = get_model_type(tokenizer);
+
+  if (model_type == ModelType::llama2) {
+    // llama2 uses EOS as EOT token
+    return tokenizer->eos_tok();
+  }
+
+  if (model_type == ModelType::llama3) {
+    auto tokens = tokenizer->encode("<|eot_id|>", 0, 0);
+    return tokens[0];
+  }
+
+  fprintf(stderr, "No chat template implemnation for model type %d", model_type);
+  exit(EXIT_FAILURE);
+}
+
+std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+  char system_prompt[512];
+  char user_prompt[512];
+  char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+
+  if (cli_system_prompt != NULL) {
+    strcpy(system_prompt, cli_system_prompt);
+  } else {
+    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+  }
+
+  if (cli_user_prompt != NULL) {
+    strcpy(user_prompt, cli_user_prompt);
+  } else {
+    read_stdin("User: ", user_prompt, sizeof(user_prompt));
+  }
+
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
+
+  switch (model_type) {
+
+    case ModelType::llama2:
+      if (system_prompt[0] != '\0') {
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+          system_prompt,
+          user_prompt
+        );
+      } else {
+        // const char prompt_template[] = ;
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "[INST] %s [/INST]",
+          user_prompt
+        );
+      }
+
+      // We need to add BOS token here and not in template because llama2 tokenizer
+      // does not pattern match special tokens
+      tokens = tokenizer->encode(rendered_prompt, 1, 0);
+      break;
+
+    case ModelType::llama3:
+      if (system_prompt[0] != '\0') {
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          system_prompt,
+          user_prompt
+        );
+      } else {
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          user_prompt
+        );
+      }
+      tokens = tokenizer->encode(rendered_prompt, 0, 0);
+      break;
+
+    default:
+      fprintf(stderr, "No chat template implemnation for model type %d", model_type);
+      exit(EXIT_FAILURE);
+  }
+
+#ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+#endif
+
+  return tokens;
+}
+
+std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
+  char user_prompt[512];
+  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+
+  read_stdin("User: ", user_prompt, sizeof(user_prompt));
+
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
+
+  switch (model_type) {
+
+    case ModelType::llama2:
+      // const char prompt_template[] = ;
+      snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
+
+      // We need to add BOS token here and not in template because llama2 tokenizer
+      // does not pattern match special tokens
+      tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
+      break;
+
+    case ModelType::llama3:
+      snprintf(
+        rendered_prompt,
+        sizeof(rendered_prompt)-1,
+        "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+        user_prompt
+      );
+      tokens = tokenizer->encode(rendered_prompt, 0, 0);
+      break;
+
+    default:
+      fprintf(stderr, "No chat template implemnation for model type %d", model_type);
+      exit(EXIT_FAILURE);
+  }
+
+
+#ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+#endif
+
+  return tokens;
+}
+
+
 void chat(
     Transformer* transformer,
     Tokenizer* tokenizer,
     Sampler* sampler,
     const char* cli_user_prompt,
     const char* cli_system_prompt,
     int steps) {
-  // special tokens
-  const int SOS_TOKEN = tokenizer->bos_tok(); // token starts the assistant turn
-  const int EOS_TOKEN = tokenizer->eos_tok(); // token ends the assistant turn
-  const int SYSTEM_PROMPT_SIZE = 512;
-  const int USER_PROMPT_SIZE = 512;
-  const int RENDERED_PROMPT_SIZE = SYSTEM_PROMPT_SIZE + USER_PROMPT_SIZE + 128; // This is big enough to hold the expanded template
-
-
 
-  // buffers for reading the system prompt and user prompt from stdin
-  // you'll notice they are soomewhat haphazardly and unsafely set atm
-  char system_prompt[SYSTEM_PROMPT_SIZE];
-  char user_prompt[USER_PROMPT_SIZE];
-  char rendered_prompt[RENDERED_PROMPT_SIZE];
+  const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
   int num_prompt_tokens = 0;
   std::vector<uint64_t> prompt_tokens;
   int user_idx;
@@ -522,41 +684,10 @@ void chat(
     if (user_turn) {
       // get the (optional) system prompt at position 0
       if (pos == 0) {
-        // at position 0, the user can also contribute a system prompt
-        if (cli_system_prompt == NULL) {
-          // system prompt was not passed in, attempt to get it from stdin
-          read_stdin(
-              "Enter system prompt (optional): ",
-              system_prompt,
-              sizeof(system_prompt));
-        } else {
-          // system prompt was passed in, use it
-          strcpy(system_prompt, cli_system_prompt);
-        }
-      }
-      // get the user prompt
-      if (pos == 0 && cli_user_prompt != NULL) {
-        // user prompt for position 0 was passed in, use it
-        strcpy(user_prompt, cli_user_prompt);
+        prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
       } else {
-        // otherwise get user prompt from stdin
-        read_stdin("User: ", user_prompt, sizeof(user_prompt));
+        prompt_tokens = get_next_user_prompt_tokens(tokenizer);
       }
-      // render user/system prompts into the Llama 2 Chat schema
-      if (pos == 0 && system_prompt[0] != '\0') {
-        // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-        const char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
-        snprintf(
-            rendered_prompt, RENDERED_PROMPT_SIZE-1, system_template, system_prompt, user_prompt);
-      } else {
-        // Assistant should produce </s>, so we do not include it in template
-        // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-        const char user_template[] = "[INST] %s [/INST]";
-        snprintf(rendered_prompt, RENDERED_PROMPT_SIZE-1, user_template, user_prompt);
-      }
-
-      // encode the rendered prompt into tokens
-      prompt_tokens = tokenizer->encode(rendered_prompt, 1, 0);
       num_prompt_tokens = prompt_tokens.size();
 
       user_idx = 0; // reset the user index
@@ -578,19 +709,21 @@ void chat(
       float* logits = forward(transformer, token, pos);
       next = sample(sampler, logits);
 
+      // std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;
 
-    if (token == EOS_TOKEN) {
+
+    if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
       user_turn = 1;
     }
 
-    if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
+    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
       std::string piece = tokenizer->decode(token, next);
       safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                   // "unsafe" bytes
       fflush(stdout);
     }
 
-    if (next == EOS_TOKEN) {
+    if (next == EOT_TOKEN) {
       printf("\n");
     }
     pos++;
@@ -619,6 +752,7 @@ void error_usage() {
   fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
   fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
   fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
+  fprintf(stderr, " -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
   exit(EXIT_FAILURE);
 }
 
@@ -630,14 +764,17 @@ int main(int argc, char* argv[]) {
       1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
   float topp =
       0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-  int vocab_size = 32000;
+
   int steps = 256; // number of steps to run for
   const char* prompt = NULL; // prompt string
   unsigned long long rng_seed = 0; // seed rng with time by default
   const char* mode = "generate"; // generate|chat
   char* system_prompt =
       NULL; // the (optional) system prompt to use in chat mode
 
+  int vocab_size = -1;
+  int llama_ver = 2;
+
 #if defined(ET_USE_ADPATIVE_THREADS)
   uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
   if (num_performant_cores > 0) {
@@ -682,7 +819,10 @@ int main(int argc, char* argv[]) {
       mode = argv[i + 1];
     } else if (argv[i][1] == 'y') {
       system_prompt = argv[i + 1];
-    } else {
+    } else if (argv[i][1] == 'l') {
+      llama_ver = atoi(argv[i+1]);
+    }
+    else {
       error_usage();
     }
   }
@@ -697,27 +837,35 @@ int main(int argc, char* argv[]) {
   if (steps < 0)
     steps = 0;
 
+
+  if (vocab_size == -1) {
+    if (llama_ver == 2) {
+      vocab_size = 32000;
+    } else {
+      vocab_size = 128256;
+    }
+  }
+
   // build the Transformer via the model .bin file
   Transformer transformer;
   build_transformer(&transformer, checkpoint_path, vocab_size, steps);
 
   // build the Tokenizer via the tokenizer .bin file
   Tokenizer* tokenizer = nullptr;
 
-  // Try to load using Tiktoken, if exception then switch to another tokenizer
-  try {
-    tokenizer =
-        new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-    tokenizer->load(tokenizer_path);
-  } catch (const std::invalid_argument&) {
-    fprintf(
-        stderr,
-        "Failed to load %s into a Tiktoken tokenizer. Trying sentencepiece tokenizer..\n",
-        tokenizer_path);
-    delete tokenizer;
-    tokenizer =
-        new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-    tokenizer->load(tokenizer_path);
+  switch (llama_ver) {
+    case 2:
+      tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer->load(tokenizer_path);
+      break;
+    case 3:
+      tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer->load(tokenizer_path);
+      break;
+
+    default:
+      fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
+      exit(EXIT_FAILURE);
   }
 
   // build the Sampler
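
A note on the dispatch above: get_model_type infers the model family from the concrete tokenizer class via dynamic_cast instead of threading llama_ver through chat(). The self-contained sketch below illustrates the same pattern; DemoTokenizer, DemoBPETokenizer, and DemoTiktoken are hypothetical stand-ins for the real Tokenizer, BPETokenizer, and Tiktoken classes from tokenizer.h. dynamic_cast requires a polymorphic base, which the virtual destructor provides here.

// Hypothetical stand-ins that mirror the dynamic_cast dispatch used by
// get_model_type in runner/run.cpp; not the real tokenizer classes.
#include <stdio.h>

struct DemoTokenizer {
  virtual ~DemoTokenizer() = default; // polymorphic base so dynamic_cast works
};
struct DemoBPETokenizer : DemoTokenizer {}; // sentencepiece-style, i.e. llama2
struct DemoTiktoken : DemoTokenizer {};     // tiktoken-style, i.e. llama3

enum class ModelType { unknown, llama2, llama3 };

ModelType get_model_type(DemoTokenizer* tokenizer) {
  if (dynamic_cast<DemoBPETokenizer*>(tokenizer)) {
    return ModelType::llama2;
  } else if (dynamic_cast<DemoTiktoken*>(tokenizer)) {
    return ModelType::llama3;
  }
  return ModelType::unknown;
}

int main() {
  DemoBPETokenizer bpe;
  DemoTiktoken tiktoken;
  printf("bpe -> llama2: %s\n",
         get_model_type(&bpe) == ModelType::llama2 ? "yes" : "no");
  printf("tiktoken -> llama3: %s\n",
         get_model_type(&tiktoken) == ModelType::llama3 ? "yes" : "no");
  return 0;
}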

tokenizer/tiktoken.cpp

Lines changed: 6 additions & 2 deletions
@@ -331,9 +331,9 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
 
 Tiktoken::Tiktoken(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
     : Tokenizer(vocab_size, bos_tok, eos_tok) {
-  _regex = _create_regex(_pattern);
 
-  _special_token_regex = _build_special_token_regex(_special_token_encoder);
+  // _regex = _create_regex(_pattern);
+  // _special_token_regex = _build_special_token_regex(_special_token_encoder);
 }
 
 void Tiktoken::load(const std::string& path) {
@@ -343,6 +343,10 @@ void Tiktoken::load(const std::string& path) {
   _decoder = _build_decoder(_encoder);
   _special_token_decoder = _build_decoder(_special_token_encoder);
 
+
+  _regex = _create_regex(_pattern);
+  _special_token_regex = _build_special_token_regex(_special_token_encoder);
+
   initialized_ = true;
 }
 
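On the tokenizer/tiktoken.cpp side, the fix moves construction of _regex and _special_token_regex out of the constructor and into load(), presumably because _special_token_encoder is only populated once the tokenizer file has been read, so a regex built in the constructor would be derived from an empty special-token table. The sketch below illustrates that deferred-initialization pattern; DemoTiktoken, its members, and the use of std::regex are illustrative stand-ins, not the real Tiktoken implementation.

// Illustrative sketch of deferring regex construction until load(), when the
// special-token table actually exists. Not the real Tiktoken implementation.
#include <stdio.h>
#include <map>
#include <regex>
#include <string>

class DemoTiktoken {
 public:
  DemoTiktoken() {
    // Intentionally empty: the special-token table is not known yet, so
    // building _special_token_regex here would capture an empty table.
  }

  void load(const std::string& /*path*/) {
    // Normally parsed from the tokenizer file; hard-coded for the demo.
    _special_token_encoder = {{"<|begin_of_text|>", 128000}, {"<|eot_id|>", 128009}};

    // Build the special-token regex only now that the table is populated.
    // Token names contain '|', which must be escaped inside the alternation.
    std::string pattern;
    for (const auto& kv : _special_token_encoder) {
      if (!pattern.empty()) pattern += "|";
      pattern += std::regex_replace(kv.first, std::regex(R"(\|)"), R"(\|)");
    }
    _special_token_regex = std::regex(pattern);
    _initialized = true;
  }

  bool matches_special(const std::string& s) const {
    return _initialized && std::regex_match(s, _special_token_regex);
  }

 private:
  std::map<std::string, int> _special_token_encoder;
  std::regex _special_token_regex;
  bool _initialized = false;
};

int main() {
  DemoTiktoken tokenizer;
  tokenizer.load("tokenizer.model");
  printf("matches <|eot_id|>: %s\n",
         tokenizer.matches_special("<|eot_id|>") ? "yes" : "no");
  return 0;
}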