@@ -49,22 +49,22 @@ using torch::executor::Result;
 // ----------------------------------------------------------------------------
 // Transformer model
 
-enum class ModelType {
-  unknown = 0,
-  llama2 = 2,
-  llama3 = 3,
+enum ModelType {
+  UNKNOWN_MODEL = 0,
+  LLAMA2_MODEL = 2,
+  LLAMA3_MODEL = 3,
 };
 
 ModelType get_model_type(int model_int) {
   switch (model_int) {
     case 2:
-      return ModelType::llama2;
+      return LLAMA2_MODEL;
       break;
     case 3:
-      return ModelType::llama3;
+      return LLAMA3_MODEL;
       break;
     default:
-      return ModelType::unknown;
+      return UNKNOWN_MODEL;
   }
 }
 
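Aside, not part of the diff: a minimal sketch of why the switch to an unscoped enum pairs with the new fprintf-style error messages below. Unscoped enumerators live in the enclosing scope and convert implicitly to int, so model_type can be passed straight to a %d format specifier without the static_cast<int> the old exception messages needed (assumes <cstdio> is included, as it is in this runner).

    ModelType t = get_model_type(3);  // yields LLAMA3_MODEL
    printf("model type = %d\n", t);   // prints "model type = 3"; no cast needed
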
@@ -381,20 +381,19 @@ Tokenizer* build_tokenizer(
     const char* tokenizer_path,
     ModelType model_type,
     int vocab_size) {
-  Tokenizer* tokenizer = nullptr;
+  Tokenizer* tokenizer = NULL;
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       tokenizer = new BPETokenizer(vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       tokenizer = new Tiktoken(vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
     default:
-      throw std::runtime_error(
-          "No tokenizer defined for model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "No tokenizer defined for model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
   return tokenizer;
 }
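Aside, not part of the diff: a hedged sketch of how the updated build_tokenizer might be called; the file name and vocab size are illustrative placeholders, not values taken from this change.

    // Hypothetical call site; "tokenizer.bin" and 32000 are examples only.
    Tokenizer* tok = build_tokenizer("tokenizer.bin", LLAMA2_MODEL, 32000);
    // ... run generation or chat with tok ...
    delete tok;  // the function returns a raw pointer allocated with new
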
@@ -410,7 +409,7 @@ void safe_printf(const char* piece) {
   // piece might be a raw byte token, and we only want to print printable chars
   // or whitespace because some of the other bytes can be various control codes,
   // backspace, etc.
-  if (piece == nullptr) {
+  if (piece == NULL) {
     return;
   }
   if (piece[0] == '\0') {
@@ -539,7 +538,7 @@ void generate(
     int steps,
     ModelType model_type) {
   const char* default_prompt = "Once upon a time";
-  if (prompt == nullptr) {
+  if (prompt == NULL) {
     prompt = default_prompt;
   }
 
@@ -550,11 +549,11 @@ void generate(
   std::vector<uint64_t> prompt_tokens;
   std::vector<uint64_t> stop_tokens;
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       prompt_tokens = tokenizer->encode(prompt, 1, 0);
       stop_tokens.push_back(tokenizer->eos_tok());
       break;
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       prompt_tokens = tokenizer->encode(prompt, 0, 0);
       prompt_tokens.insert(
           prompt_tokens.begin(),
@@ -563,9 +562,8 @@ void generate(
       stop_tokens.push_back(tokenizer->encode("<|eot_id|>", 0, 0)[0]);
       break;
     default:
-      throw std::runtime_error(
-          "Generate does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Generate does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
   generate_from_prompt_tokens(
@@ -583,7 +581,7 @@ void generate(
 void read_stdin(const char* guide, char* buffer, size_t bufsize) {
   // read a line from stdin, up to but not including \n
   printf("%s", guide);
-  if (fgets(buffer, bufsize, stdin) != nullptr) {
+  if (fgets(buffer, bufsize, stdin) != NULL) {
     size_t len = strlen(buffer);
     if (len > 0 && buffer[len - 1] == '\n') {
       buffer[len - 1] = '\0'; // strip newline
@@ -607,7 +605,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
                                        // characters. We use 200 to be safe.
 
-  if (cli_system_prompt != nullptr) {
+  if (cli_system_prompt != NULL) {
     strcpy(system_prompt, cli_system_prompt);
   } else {
     read_stdin(
@@ -616,7 +614,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
         sizeof(system_prompt));
   }
 
-  if (cli_user_prompt != nullptr) {
+  if (cli_user_prompt != NULL) {
     strcpy(user_prompt, cli_user_prompt);
   } else {
     read_stdin("User: ", user_prompt, sizeof(user_prompt));
@@ -625,7 +623,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       if (system_prompt[0] != '\0') {
         snprintf(
             rendered_prompt,
@@ -646,7 +644,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
       tokens = tokenizer->encode(rendered_prompt, 1, 0);
       break;
 
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       if (system_prompt[0] != '\0') {
         snprintf(
             rendered_prompt,
@@ -665,9 +663,8 @@ std::vector<uint64_t> get_initial_prompt_tokens(
       break;
 
     default:
-      throw std::runtime_error(
-          "Chat does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
 #ifdef DEBUG
@@ -695,7 +692,7 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       snprintf(
           rendered_prompt,
           sizeof(rendered_prompt) - 1,
@@ -707,7 +704,7 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
       tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
       break;
 
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       snprintf(
           rendered_prompt,
           sizeof(rendered_prompt) - 1,
@@ -717,9 +714,8 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
       break;
 
     default:
-      throw std::runtime_error(
-          "Chat does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
 #ifdef DEBUG
@@ -751,17 +747,16 @@ void chat(
   uint64_t eot_token;
   std::vector<uint64_t> prompt_tokens;
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       // llama2 uses EOS as EOT token
       eot_token = tokenizer->eos_tok();
       break;
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       eot_token = tokenizer->encode("<|eot_id|>", 0, 0)[0];
       break;
     default:
-      throw std::runtime_error(
-          "Chat does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
   std::vector<uint64_t> stop_tokens{eot_token};
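Aside, not part of the diff: a small helper sketch of the kind of stop check the chat loop performs with the eot_token chosen above. The helper name is hypothetical and the sketch assumes <algorithm>, <cstdint>, and <vector> are included.

    // Returns true when the sampled token ends the assistant's turn.
    bool is_turn_done(uint64_t token, const std::vector<uint64_t>& stop_tokens) {
      return std::find(stop_tokens.begin(), stop_tokens.end(), token) !=
          stop_tokens.end();
    }
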
@@ -801,7 +796,7 @@ void error_usage() {
   fprintf(
       stderr,
       "  -p <float>  p value in top-p (nucleus) sampling in [0,1], default 0.9\n");
-  fprintf(stderr, "  -s <int>    random seed, default time(nullptr)\n");
+  fprintf(stderr, "  -s <int>    random seed, default time(NULL)\n");
   fprintf(
       stderr,
       "  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n");
@@ -819,19 +814,19 @@ void error_usage() {
 
 int main(int argc, char* argv[]) {
   // default parameters
-  char* model_path = nullptr;
-  char* tokenizer_path = nullptr;
+  char* model_path = NULL;
+  char* tokenizer_path = NULL;
   float temperature =
       1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
   float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well,
                      // but slower
 
   int steps = 256; // number of steps to run for
-  const char* prompt = nullptr; // prompt string
+  const char* prompt = NULL; // prompt string
   unsigned long long rng_seed = 0; // seed rng with time by default
   const char* mode = "generate"; // generate|chat
   char* system_prompt =
-      nullptr; // the (optional) system prompt to use in chat mode
+      NULL; // the (optional) system prompt to use in chat mode
 
   int vocab_size = -1;
   int llama_ver = 2;
@@ -889,27 +884,27 @@ int main(int argc, char* argv[]) {
   }
 
   ModelType model_type = get_model_type(llama_ver);
-  if (model_type == ModelType::unknown) {
+  if (model_type == UNKNOWN_MODEL) {
     fprintf(
         stderr,
         "Unknown model type passed by -l argument. Received l=%d.",
         llama_ver);
     error_usage();
   }
 
-  if (model_path == nullptr) {
+  if (model_path == NULL) {
     fprintf(stderr, "No model_path provided.");
     error_usage();
   }
 
-  if (tokenizer_path == nullptr) {
+  if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
   }
 
   // parameter validation/overrides
   if (rng_seed <= 0)
-    rng_seed = (unsigned int)time(nullptr);
+    rng_seed = (unsigned int)time(NULL);
   if (temperature < 0.0)
     temperature = 0.0;
   if (topp < 0.0 || 1.0 < topp)
@@ -920,16 +915,16 @@ int main(int argc, char* argv[]) {
   // If no tokenizer path provided, get default for model_type
   if (vocab_size == -1) {
     switch (model_type) {
-      case ModelType::llama2:
+      case LLAMA2_MODEL:
         vocab_size = 32000;
         break;
-      case ModelType::llama3:
+      case LLAMA3_MODEL:
        vocab_size = 128256;
        break;
      default:
        fprintf(
            stderr,
-            "No vocab_size was provided with -v argument, and there is no default vocab_size for model_type ModelType::%d.",
+            "No vocab_size was provided with -v argument, and there is no default vocab_size for model_type %d.\n",
            model_type);
        error_usage();
    }