
Commit 72eb661

switch more things to C
1 parent 56c66a7 commit 72eb661

File tree

1 file changed (+45, -50 lines)

runner/run.cpp

Lines changed: 45 additions & 50 deletions
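
The patch replaces C++-specific constructs in runner/run.cpp with their C counterparts: the scoped `enum class ModelType` becomes a plain `enum` with ALL_CAPS enumerators, `nullptr` becomes `NULL`, and the `throw std::runtime_error(...)` error paths become `fprintf(stderr, ...)` followed by `exit(EXIT_FAILURE)`. Below is a minimal standalone sketch of the resulting C-style pattern; the enum and get_model_type mirror the diff, while the main function is a hypothetical usage example and is not part of the patch.

// Hypothetical standalone sketch of the C-style idioms this commit adopts.
#include <cstdio>
#include <cstdlib>

enum ModelType {
  UNKNOWN_MODEL = 0,
  LLAMA2_MODEL = 2,
  LLAMA3_MODEL = 3,
};

ModelType get_model_type(int model_int) {
  switch (model_int) {
    case 2:
      return LLAMA2_MODEL;
    case 3:
      return LLAMA3_MODEL;
    default:
      return UNKNOWN_MODEL;
  }
}

int main(int argc, char* argv[]) {
  const char* prompt = NULL; // C-style null, as in the patch
  if (argc > 1) {
    prompt = argv[1];
  }
  ModelType model_type = get_model_type(argc > 2 ? atoi(argv[2]) : 2);
  if (model_type == UNKNOWN_MODEL) {
    // Error path: report to stderr and exit instead of throwing.
    fprintf(stderr, "Unknown model type.\n");
    exit(EXIT_FAILURE);
  }
  printf("model_type=%d prompt=%s\n", (int)model_type,
         prompt ? prompt : "Once upon a time");
  return 0;
}
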
@@ -49,22 +49,22 @@ using torch::executor::Result;
 // ----------------------------------------------------------------------------
 // Transformer model
 
-enum class ModelType {
-  unknown = 0,
-  llama2 = 2,
-  llama3 = 3,
+enum ModelType {
+  UNKNOWN_MODEL = 0,
+  LLAMA2_MODEL = 2,
+  LLAMA3_MODEL = 3,
 };
 
 ModelType get_model_type(int model_int) {
   switch (model_int) {
     case 2:
-      return ModelType::llama2;
+      return LLAMA2_MODEL;
       break;
     case 3:
-      return ModelType::llama3;
+      return LLAMA3_MODEL;
       break;
     default:
-      return ModelType::unknown;
+      return UNKNOWN_MODEL;
   }
 }
 
@@ -381,20 +381,19 @@ Tokenizer* build_tokenizer(
     const char* tokenizer_path,
     ModelType model_type,
     int vocab_size) {
-  Tokenizer* tokenizer = nullptr;
+  Tokenizer* tokenizer = NULL;
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       tokenizer = new BPETokenizer(vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       tokenizer = new Tiktoken(vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
     default:
-      throw std::runtime_error(
-          "No tokenizer defined for model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "No tokenizer defined for model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
   return tokenizer;
 }
@@ -410,7 +409,7 @@ void safe_printf(const char* piece) {
   // piece might be a raw byte token, and we only want to print printable chars
   // or whitespace because some of the other bytes can be various control codes,
   // backspace, etc.
-  if (piece == nullptr) {
+  if (piece == NULL) {
     return;
   }
   if (piece[0] == '\0') {
@@ -539,7 +538,7 @@ void generate(
     int steps,
     ModelType model_type) {
   const char* default_prompt = "Once upon a time";
-  if (prompt == nullptr) {
+  if (prompt == NULL) {
     prompt = default_prompt;
   }
 
@@ -550,11 +549,11 @@ void generate(
   std::vector<uint64_t> prompt_tokens;
   std::vector<uint64_t> stop_tokens;
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       prompt_tokens = tokenizer->encode(prompt, 1, 0);
       stop_tokens.push_back(tokenizer->eos_tok());
       break;
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       prompt_tokens = tokenizer->encode(prompt, 0, 0);
       prompt_tokens.insert(
           prompt_tokens.begin(),
@@ -563,9 +562,8 @@ void generate(
       stop_tokens.push_back(tokenizer->encode("<|eot_id|>", 0, 0)[0]);
       break;
     default:
-      throw std::runtime_error(
-          "Generate does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Generate does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
   generate_from_prompt_tokens(
@@ -583,7 +581,7 @@ void generate(
 void read_stdin(const char* guide, char* buffer, size_t bufsize) {
   // read a line from stdin, up to but not including \n
   printf("%s", guide);
-  if (fgets(buffer, bufsize, stdin) != nullptr) {
+  if (fgets(buffer, bufsize, stdin) != NULL) {
     size_t len = strlen(buffer);
     if (len > 0 && buffer[len - 1] == '\n') {
       buffer[len - 1] = '\0'; // strip newline
@@ -607,7 +605,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
                                        // characters. We use 200 to be safe.
 
-  if (cli_system_prompt != nullptr) {
+  if (cli_system_prompt != NULL) {
     strcpy(system_prompt, cli_system_prompt);
   } else {
     read_stdin(
@@ -616,7 +614,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
         sizeof(system_prompt));
   }
 
-  if (cli_user_prompt != nullptr) {
+  if (cli_user_prompt != NULL) {
     strcpy(user_prompt, cli_user_prompt);
   } else {
     read_stdin("User: ", user_prompt, sizeof(user_prompt));
@@ -625,7 +623,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       if (system_prompt[0] != '\0') {
         snprintf(
             rendered_prompt,
@@ -646,7 +644,7 @@ std::vector<uint64_t> get_initial_prompt_tokens(
       tokens = tokenizer->encode(rendered_prompt, 1, 0);
       break;
 
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       if (system_prompt[0] != '\0') {
         snprintf(
             rendered_prompt,
@@ -665,9 +663,8 @@ std::vector<uint64_t> get_initial_prompt_tokens(
       break;
 
     default:
-      throw std::runtime_error(
-          "Chat does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
 #ifdef DEBUG
@@ -695,7 +692,7 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       snprintf(
           rendered_prompt,
           sizeof(rendered_prompt) - 1,
@@ -707,7 +704,7 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
       tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
       break;
 
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       snprintf(
           rendered_prompt,
           sizeof(rendered_prompt) - 1,
@@ -717,9 +714,8 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
       break;
 
     default:
-      throw std::runtime_error(
-          "Chat does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
 #ifdef DEBUG
@@ -751,17 +747,16 @@ void chat(
   uint64_t eot_token;
   std::vector<uint64_t> prompt_tokens;
   switch (model_type) {
-    case ModelType::llama2:
+    case LLAMA2_MODEL:
       // llama2 uses EOS as EOT token
       eot_token = tokenizer->eos_tok();
       break;
-    case ModelType::llama3:
+    case LLAMA3_MODEL:
       eot_token = tokenizer->encode("<|eot_id|>", 0, 0)[0];
       break;
     default:
-      throw std::runtime_error(
-          "Chat does not support model type " +
-          std::to_string(static_cast<int>(model_type)));
+      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+      exit(EXIT_FAILURE);
   }
 
   std::vector<uint64_t> stop_tokens{eot_token};
@@ -801,7 +796,7 @@ void error_usage() {
   fprintf(
       stderr,
      " -p <float> p value in top-p (nucleus) sampling in [0,1], default 0.9\n");
-  fprintf(stderr, " -s <int> random seed, default time(nullptr)\n");
+  fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
   fprintf(
       stderr,
      " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
@@ -819,19 +814,19 @@ void error_usage() {
 
 int main(int argc, char* argv[]) {
   // default parameters
-  char* model_path = nullptr;
-  char* tokenizer_path = nullptr;
+  char* model_path = NULL;
+  char* tokenizer_path = NULL;
   float temperature =
       1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
   float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well,
                      // but slower
 
   int steps = 256; // number of steps to run for
-  const char* prompt = nullptr; // prompt string
+  const char* prompt = NULL; // prompt string
   unsigned long long rng_seed = 0; // seed rng with time by default
   const char* mode = "generate"; // generate|chat
   char* system_prompt =
-      nullptr; // the (optional) system prompt to use in chat mode
+      NULL; // the (optional) system prompt to use in chat mode
 
   int vocab_size = -1;
   int llama_ver = 2;
@@ -889,27 +884,27 @@ int main(int argc, char* argv[]) {
   }
 
   ModelType model_type = get_model_type(llama_ver);
-  if (model_type == ModelType::unknown) {
+  if (model_type == UNKNOWN_MODEL) {
     fprintf(
         stderr,
         "Unknown model type passed by -l argument. Received l=%d.",
         llama_ver);
     error_usage();
   }
 
-  if (model_path == nullptr) {
+  if (model_path == NULL) {
     fprintf(stderr, "No model_path provided.");
     error_usage();
   }
 
-  if (tokenizer_path == nullptr) {
+  if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
   }
 
   // parameter validation/overrides
   if (rng_seed <= 0)
-    rng_seed = (unsigned int)time(nullptr);
+    rng_seed = (unsigned int)time(NULL);
   if (temperature < 0.0)
     temperature = 0.0;
   if (topp < 0.0 || 1.0 < topp)
@@ -920,16 +915,16 @@ int main(int argc, char* argv[]) {
   // If no tokenizer path provided, get default for model_type
   if (vocab_size == -1) {
     switch (model_type) {
-      case ModelType::llama2:
+      case LLAMA2_MODEL:
        vocab_size = 32000;
        break;
-      case ModelType::llama3:
+      case LLAMA3_MODEL:
        vocab_size = 128256;
        break;
      default:
        fprintf(
            stderr,
-            "No vocab_size was provided with -v argument, and there is no default vocab_size for model_type ModelType::%d.",
+            "No vocab_size was provided with -v argument, and there is no default vocab_size for model_type %d.\n",
            model_type);
        error_usage();
    }
