/* Inference for Llama-2 Transformer model in pure C++ */
- #include <cstdint>
- #include <cstdlib>
#include <ctype.h>
- #include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <tokenizer.h>
+ #include <cstdint>
+ #include <cstdlib>
+ #include <iterator>
#include <string>

-
#ifdef DEBUG
#include <cassert>
#include <iostream>
@@ -167,22 +166,14 @@ float* forward(Transformer* transformer, int token, int pos) {
torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};

- torch::Tensor result = transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
+ torch::Tensor result =
+ transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
auto logits = result[0].data_ptr();

#else // __ET_MODEL__
ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
- #ifndef __KV_CACHE__
- // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
- ManagedTensor tokens_managed(
- &(s->toks[pos]),
- /*ignored*/ sizeof(int64_t) * (pos + 1),
- {1, 1},
- ScalarType::Long);
- #else // __KV_CACHE__
ManagedTensor tokens_managed(
token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
- #endif
std::vector<EValue> inputs;
auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
auto tmp2 = EValue(pos_managed.get_aliasing_tensor());
@@ -491,9 +482,9 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// is not safely implemented, it's more a proof of concept atm.

enum class ModelType {
- unknown,
- llama2,
- llama3,
+ unknown,
+ llama2,
+ llama3,
};

ModelType get_model_type(Tokenizer* tokenizer) {
@@ -519,19 +510,27 @@ uint64_t get_eot_token(Tokenizer* tokenizer) {
return tokens[0];
}

- fprintf(stderr, "No chat template implementation for model type %d", model_type);
+ fprintf(
+ stderr, "No chat template implementation for model type %d", model_type);
exit(EXIT_FAILURE);
}

- std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+ std::vector<uint64_t> get_initial_prompt_tokens(
+ const char* cli_system_prompt,
+ const char* cli_user_prompt,
+ Tokenizer* tokenizer) {
char system_prompt[512];
char user_prompt[512];
- char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+ char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
+ // characters. We use 200 to be safe.

if (cli_system_prompt != NULL) {
strcpy(system_prompt, cli_system_prompt);
} else {
- read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+ read_stdin(
+ "Enter system prompt (optional): ",
+ system_prompt,
+ sizeof(system_prompt));
}

if (cli_user_prompt != NULL) {
@@ -540,111 +539,114 @@ std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, c
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}

- ModelType model_type = get_model_type(tokenizer);
- std::vector<uint64_t> tokens;
-
- switch (model_type) {
+ ModelType model_type = get_model_type(tokenizer);
+ std::vector<uint64_t> tokens;

+ switch (model_type) {
case ModelType::llama2:
if (system_prompt[0] != '\0') {
snprintf(
- rendered_prompt,
- sizeof(rendered_prompt)-1,
- "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
- system_prompt,
- user_prompt
- );
+ rendered_prompt,
+ sizeof(rendered_prompt) - 1,
+ "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+ system_prompt,
+ user_prompt);
} else {
// const char prompt_template[] = ;
snprintf(
- rendered_prompt,
- sizeof(rendered_prompt)-1,
- "[INST] %s [/INST]",
- user_prompt
- );
+ rendered_prompt,
+ sizeof(rendered_prompt) - 1,
+ "[INST] %s [/INST]",
+ user_prompt);
}

- // We need to add BOS token here and not in template because llama2 tokenizer
- // does not pattern match special tokens
+ // We need to add BOS token here and not in template because llama2
+ // tokenizer does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, 1, 0);
break;

case ModelType::llama3:
if (system_prompt[0] != '\0') {
snprintf(
- rendered_prompt,
- sizeof(rendered_prompt)-1,
- "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
- system_prompt,
- user_prompt
- );
+ rendered_prompt,
+ sizeof(rendered_prompt) - 1,
+ "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+ system_prompt,
+ user_prompt);
} else {
snprintf(
- rendered_prompt,
- sizeof(rendered_prompt)-1,
- "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
- user_prompt
- );
+ rendered_prompt,
+ sizeof(rendered_prompt) - 1,
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+ user_prompt);
}
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
- fprintf(stderr, "No chat template implementation for model type %d", model_type);
+ fprintf(
+ stderr,
+ "No chat template implementation for model type %d",
+ model_type);
exit(EXIT_FAILURE);
- }
+ }

- #ifdef DEBUG
- std::cerr << "Start of rendered prompt:" << std::endl;
- std::cerr << rendered_prompt;
- std::cerr << "End of rendered prompt:" << std::endl;
- std::cerr << "Encoded prompt: ";
- for (int i = 0; i < tokens.size(); i++) {
- std::cerr << tokens[i] << ", ";
- }
- std::cerr << std::endl << std::flush;
- #endif
+ #ifdef DEBUG
+ std::cerr << "Start of rendered prompt:" << std::endl;
+ std::cerr << rendered_prompt;
+ std::cerr << "End of rendered prompt:" << std::endl;
+ std::cerr << "Encoded prompt: ";
+ for (int i = 0; i < tokens.size(); i++) {
+ std::cerr << tokens[i] << ", ";
+ }
+ std::cerr << std::endl << std::flush;
+ #endif

- return tokens;
+ return tokens;
}

std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
char user_prompt[512];
- char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+ char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
+ // use 150 to be safe.

read_stdin("User: ", user_prompt, sizeof(user_prompt));

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {
-
case ModelType::llama2:
// const char prompt_template[] = ;
- snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
+ snprintf(
+ rendered_prompt,
+ sizeof(rendered_prompt) - 1,
+ "[INST] %s [/INST]",
+ user_prompt);

- // We need to add BOS token here and not in template because llama2 tokenizer
- // does not pattern match special tokens
- tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
+ // We need to add BOS token here and not in template because llama2
+ // tokenizer does not pattern match special tokens
+ tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
break;

case ModelType::llama3:
snprintf(
- rendered_prompt,
- sizeof(rendered_prompt)-1,
- "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
- user_prompt
- );
+ rendered_prompt,
+ sizeof(rendered_prompt) - 1,
+ "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+ user_prompt);
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
- fprintf(stderr, "No chat template implementation for model type %d", model_type);
+ fprintf(
+ stderr,
+ "No chat template implementation for model type %d",
+ model_type);
exit(EXIT_FAILURE);
}

-
- #ifdef DEBUG
+ #ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
@@ -653,20 +655,18 @@ std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
- #endif
+ #endif

return tokens;
}

-
void chat(
Transformer* transformer,
Tokenizer* tokenizer,
Sampler* sampler,
const char* cli_user_prompt,
const char* cli_system_prompt,
int steps) {
-
const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
int num_prompt_tokens = 0;
std::vector<uint64_t> prompt_tokens;
@@ -679,12 +679,12 @@ void chat(
int prev_token;
int pos = 0; // position in the sequence
while (pos < steps) {
-
// when it is the user's turn to contribute tokens to the dialog...
if (user_turn) {
// get the (optional) system prompt at position 0
if (pos == 0) {
- prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
+ prompt_tokens = get_initial_prompt_tokens(
+ cli_system_prompt, cli_user_prompt, tokenizer);
} else {
prompt_tokens = get_next_user_prompt_tokens(tokenizer);
}
@@ -711,12 +711,12 @@ void chat(

// std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;

-
if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
user_turn = 1;
}

- if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
+ if (user_idx >= num_prompt_tokens && token != EOT_TOKEN &&
+ next != EOT_TOKEN) {
std::string piece = tokenizer->decode(token, next);
safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
// "unsafe" bytes
@@ -727,7 +727,6 @@ void chat(
printf("\n");
}
pos++;
-
}
printf("\n");
}
@@ -752,7 +751,9 @@ void error_usage() {
fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
- fprintf(stderr, "  -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
+ fprintf(
+ stderr,
+ "  -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
exit(EXIT_FAILURE);
}

@@ -776,7 +777,8 @@ int main(int argc, char* argv[]) {
int llama_ver = 2;

#if defined(ET_USE_ADPATIVE_THREADS)
- uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
+ uint32_t num_performant_cores =
+ torch::executorch::cpuinfo::get_num_performant_cores();
if (num_performant_cores > 0) {
torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
num_performant_cores);
@@ -820,9 +822,8 @@ int main(int argc, char* argv[]) {
} else if (argv[i][1] == 'y') {
system_prompt = argv[i + 1];
} else if (argv[i][1] == 'l') {
- llama_ver = atoi(argv[i+1]);
- }
- else {
+ llama_ver = atoi(argv[i + 1]);
+ } else {
error_usage();
}
}
@@ -837,7 +838,6 @@ int main(int argc, char* argv[]) {
if (steps < 0)
steps = 0;

-
if (vocab_size == -1) {
if (llama_ver == 2) {
vocab_size = 32000;
@@ -855,16 +855,21 @@ int main(int argc, char* argv[]) {

switch (llama_ver) {
case 2:
- tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+ tokenizer =
+ new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;
case 3:
- tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+ tokenizer =
+ new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;

default:
- fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
+ fprintf(
+ stderr,
+ "Cannot load tokenizer for unrecognized llama version %d",
+ llama_ver);
exit(EXIT_FAILURE);
}
