/* Inference for Llama-2 Transformer model in pure C++ */
+ #include <cstdint>
+ #include <cstdlib>
#include <ctype.h>
+ #include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <tokenizer.h>
+ #include <string>
+

#ifdef DEBUG
#include <cassert>
@@ -485,27 +490,184 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.

+ enum class ModelType {
+   unknown,
+   llama2,
+   llama3,
+ };
+
+ ModelType get_model_type(Tokenizer* tokenizer) {
+   if (BPETokenizer* t = dynamic_cast<BPETokenizer*>(tokenizer)) {
+     return ModelType::llama2;
+   } else if (Tiktoken* t = dynamic_cast<Tiktoken*>(tokenizer)) {
+     return ModelType::llama3;
+   } else {
+     return ModelType::unknown;
+   }
+ }
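+
+ // The end-of-turn token differs by model family: llama2 reuses the EOS token,
+ // while llama3 has a dedicated <|eot_id|> special token.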
+ uint64_t get_eot_token(Tokenizer* tokenizer) {
+   ModelType model_type = get_model_type(tokenizer);
+
+   if (model_type == ModelType::llama2) {
+     // llama2 uses EOS as the EOT token
+     return tokenizer->eos_tok();
+   }
+
+   if (model_type == ModelType::llama3) {
+     auto tokens = tokenizer->encode("<|eot_id|>", 0, 0);
+     return tokens[0];
+   }
+
+   fprintf(stderr, "No chat template implementation for model type %d", static_cast<int>(model_type));
+   exit(EXIT_FAILURE);
+ }
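+
+ // Render the first chat turn (optional system prompt plus the first user prompt)
+ // into the chat template for the detected model family and encode it.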
+ std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+   char system_prompt[512];
+   char user_prompt[512];
+   char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+
+   if (cli_system_prompt != NULL) {
+     strcpy(system_prompt, cli_system_prompt);
+   } else {
+     read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+   }
+
+   if (cli_user_prompt != NULL) {
+     strcpy(user_prompt, cli_user_prompt);
+   } else {
+     read_stdin("User: ", user_prompt, sizeof(user_prompt));
+   }
+
+   ModelType model_type = get_model_type(tokenizer);
+   std::vector<uint64_t> tokens;
+
+   switch (model_type) {
+
+     case ModelType::llama2:
+       if (system_prompt[0] != '\0') {
+         snprintf(
+             rendered_prompt,
+             sizeof(rendered_prompt) - 1,
+             "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+             system_prompt,
+             user_prompt
+         );
+       } else {
+         snprintf(
+             rendered_prompt,
+             sizeof(rendered_prompt) - 1,
+             "[INST] %s [/INST]",
+             user_prompt
+         );
+       }
+
+       // We need to add the BOS token here and not in the template because the
+       // llama2 tokenizer does not pattern match special tokens
+       tokens = tokenizer->encode(rendered_prompt, 1, 0);
+       break;
+
+     case ModelType::llama3:
+       if (system_prompt[0] != '\0') {
+         snprintf(
+             rendered_prompt,
+             sizeof(rendered_prompt) - 1,
+             "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+             system_prompt,
+             user_prompt
+         );
+       } else {
+         snprintf(
+             rendered_prompt,
+             sizeof(rendered_prompt) - 1,
+             "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+             user_prompt
+         );
+       }
+       tokens = tokenizer->encode(rendered_prompt, 0, 0);
+       break;
+
+     default:
+       fprintf(stderr, "No chat template implementation for model type %d", static_cast<int>(model_type));
+       exit(EXIT_FAILURE);
+   }
+
+ #ifdef DEBUG
+   std::cerr << "Start of rendered prompt:" << std::endl;
+   std::cerr << rendered_prompt;
+   std::cerr << "End of rendered prompt:" << std::endl;
+   std::cerr << "Encoded prompt: ";
+   for (size_t i = 0; i < tokens.size(); i++) {
+     std::cerr << tokens[i] << ", ";
+   }
+   std::cerr << std::endl << std::flush;
+ #endif
+
+   return tokens;
+ }
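+
+ // Subsequent user turns only need the per-turn template; the system prompt and
+ // (for llama3) the <|begin_of_text|> token are emitted once in the initial prompt.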
+ std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
+   char user_prompt[512];
+   char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+
+   read_stdin("User: ", user_prompt, sizeof(user_prompt));
+
+   ModelType model_type = get_model_type(tokenizer);
+   std::vector<uint64_t> tokens;
+
+   switch (model_type) {
+
+     case ModelType::llama2:
+       snprintf(rendered_prompt, sizeof(rendered_prompt) - 1, "[INST] %s [/INST]", user_prompt);
+
+       // We need to add the BOS token here and not in the template because the
+       // llama2 tokenizer does not pattern match special tokens
+       tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
+       break;
+
+     case ModelType::llama3:
+       snprintf(
+           rendered_prompt,
+           sizeof(rendered_prompt) - 1,
+           "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+           user_prompt
+       );
+       tokens = tokenizer->encode(rendered_prompt, 0, 0);
+       break;
+
+     default:
+       fprintf(stderr, "No chat template implementation for model type %d", static_cast<int>(model_type));
+       exit(EXIT_FAILURE);
+   }
+
+ #ifdef DEBUG
+   std::cerr << "Start of rendered prompt:" << std::endl;
+   std::cerr << rendered_prompt;
+   std::cerr << "End of rendered prompt:" << std::endl;
+   std::cerr << "Encoded prompt: ";
+   for (size_t i = 0; i < tokens.size(); i++) {
+     std::cerr << tokens[i] << ", ";
+   }
+   std::cerr << std::endl << std::flush;
+ #endif
+
+   return tokens;
+ }
+
+
void chat(
    Transformer* transformer,
    Tokenizer* tokenizer,
    Sampler* sampler,
    const char* cli_user_prompt,
    const char* cli_system_prompt,
    int steps) {
-   // special tokens
-   const int SOS_TOKEN = tokenizer->bos_tok(); // token starts the assistant turn
-   const int EOS_TOKEN = tokenizer->eos_tok(); // token ends the assistant turn
-   const int SYSTEM_PROMPT_SIZE = 512;
-   const int USER_PROMPT_SIZE = 512;
-   const int RENDERED_PROMPT_SIZE = SYSTEM_PROMPT_SIZE + USER_PROMPT_SIZE + 128; // This is big enough to hold the expanded template
-
-

-   // buffers for reading the system prompt and user prompt from stdin
-   // you'll notice they are soomewhat haphazardly and unsafely set atm
-   char system_prompt[SYSTEM_PROMPT_SIZE];
-   char user_prompt[USER_PROMPT_SIZE];
-   char rendered_prompt[RENDERED_PROMPT_SIZE];
+   const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
  int num_prompt_tokens = 0;
  std::vector<uint64_t> prompt_tokens;
  int user_idx;
@@ -522,41 +684,10 @@ void chat(
    if (user_turn) {
      // get the (optional) system prompt at position 0
      if (pos == 0) {
-       // at position 0, the user can also contribute a system prompt
-       if (cli_system_prompt == NULL) {
-         // system prompt was not passed in, attempt to get it from stdin
-         read_stdin(
-             "Enter system prompt (optional): ",
-             system_prompt,
-             sizeof(system_prompt));
-       } else {
-         // system prompt was passed in, use it
-         strcpy(system_prompt, cli_system_prompt);
-       }
-     }
-     // get the user prompt
-     if (pos == 0 && cli_user_prompt != NULL) {
-       // user prompt for position 0 was passed in, use it
-       strcpy(user_prompt, cli_user_prompt);
+       prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
      } else {
-       // otherwise get user prompt from stdin
-       read_stdin("User: ", user_prompt, sizeof(user_prompt));
+       prompt_tokens = get_next_user_prompt_tokens(tokenizer);
      }
-     // render user/system prompts into the Llama 2 Chat schema
-     if (pos == 0 && system_prompt[0] != '\0') {
-       // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-       const char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
-       snprintf(
-           rendered_prompt, RENDERED_PROMPT_SIZE-1, system_template, system_prompt, user_prompt);
-     } else {
-       // Assistant should produce </s>, so we do not include it in template
-       // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-       const char user_template[] = "[INST] %s [/INST]";
-       snprintf(rendered_prompt, RENDERED_PROMPT_SIZE-1, user_template, user_prompt);
-     }
-
-     // encode the rendered prompt into tokens
-     prompt_tokens = tokenizer->encode(rendered_prompt, 1, 0);
      num_prompt_tokens = prompt_tokens.size();

      user_idx = 0; // reset the user index
@@ -578,19 +709,21 @@ void chat(
    float* logits = forward(transformer, token, pos);
    next = sample(sampler, logits);

+   // std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;

-   if (token == EOS_TOKEN) {
+   if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
      user_turn = 1;
    }

-   if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
+   if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
      std::string piece = tokenizer->decode(token, next);
      safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                  // "unsafe" bytes
      fflush(stdout);
    }

-   if (next == EOS_TOKEN) {
+   if (next == EOT_TOKEN) {
      printf("\n");
    }
    pos++;
@@ -619,6 +752,7 @@ void error_usage() {
  fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
  fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
  fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
+ fprintf(stderr, "  -l <int>    (optional) llama version (2 or 3). Defaults to 2.\n");
  exit(EXIT_FAILURE);
}

@@ -630,14 +764,17 @@ int main(int argc, char* argv[]) {
      1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
  float topp =
      0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-   int vocab_size = 32000;
+
  int steps = 256; // number of steps to run for
  const char* prompt = NULL; // prompt string
  unsigned long long rng_seed = 0; // seed rng with time by default
  const char* mode = "generate"; // generate|chat
  char* system_prompt =
      NULL; // the (optional) system prompt to use in chat mode

+   int vocab_size = -1;
+   int llama_ver = 2;
+
#if defined(ET_USE_ADPATIVE_THREADS)
  uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
  if (num_performant_cores > 0) {
@@ -682,7 +819,10 @@ int main(int argc, char* argv[]) {
      mode = argv[i + 1];
    } else if (argv[i][1] == 'y') {
      system_prompt = argv[i + 1];
-   } else {
+   } else if (argv[i][1] == 'l') {
+     llama_ver = atoi(argv[i + 1]);
+   } else {
      error_usage();
    }
  }
@@ -697,27 +837,35 @@ int main(int argc, char* argv[]) {
  if (steps < 0)
    steps = 0;

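+
+   // vocab_size == -1 means it was not set explicitly; fall back to the default
+   // vocab size for the chosen llama version (32000 for llama2, 128256 for llama3).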
+   if (vocab_size == -1) {
+     if (llama_ver == 2) {
+       vocab_size = 32000;
+     } else {
+       vocab_size = 128256;
+     }
+   }
+
  // build the Transformer via the model .bin file
  Transformer transformer;
  build_transformer(&transformer, checkpoint_path, vocab_size, steps);

  // build the Tokenizer via the tokenizer .bin file
  Tokenizer* tokenizer = nullptr;

-   // Try to load using Tiktoken, if exception then switch to another tokenizer
-   try {
-     tokenizer =
-         new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-     tokenizer->load(tokenizer_path);
-   } catch (const std::invalid_argument&) {
-     fprintf(
-         stderr,
-         "Failed to load %s into a Tiktoken tokenizer. Trying sentencepiece tokenizer..\n",
-         tokenizer_path);
-     delete tokenizer;
-     tokenizer =
-         new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-     tokenizer->load(tokenizer_path);
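+   // The tokenizer implementation is now selected from the explicit -l llama
+   // version flag rather than by probing the file and catching an exception.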
+   switch (llama_ver) {
+     case 2:
+       tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+       tokenizer->load(tokenizer_path);
+       break;
+     case 3:
+       tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+       tokenizer->load(tokenizer_path);
+       break;
+
+     default:
+       fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
+       exit(EXIT_FAILURE);
  }

  // build the Sampler