
Commit 092363f

mikekgfb authored and malfet committed
remove code for no KV Cache path (#527)
1 parent af88c63 commit 092363f

File tree

2 files changed: +98 -95 lines changed


runner/run.cpp

Lines changed: 98 additions & 93 deletions
@@ -1,18 +1,17 @@
 /* Inference for Llama-2 Transformer model in pure C++ */
-#include <cstdint>
-#include <cstdlib>
 #include <ctype.h>
-#include <iterator>
 #include <math.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 #include <time.h>
 #include <tokenizer.h>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
 #include <string>
 
-
 #ifdef DEBUG
 #include <cassert>
 #include <iostream>
@@ -167,22 +166,14 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
   std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};
 
-  torch::Tensor result = transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
+  torch::Tensor result =
+      transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
   auto logits = result[0].data_ptr();
 
 #else // __ET_MODEL__
   ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
-#ifndef __KV_CACHE__
-  // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-  ManagedTensor tokens_managed(
-      &(s->toks[pos]),
-      /*ignored*/ sizeof(int64_t) * (pos + 1),
-      {1, 1},
-      ScalarType::Long);
-#else // __KV_CACHE__
   ManagedTensor tokens_managed(
       token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
-#endif
   std::vector<EValue> inputs;
   auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
   auto tmp2 = EValue(pos_managed.get_aliasing_tensor());
@@ -491,9 +482,9 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
 // is not safely implemented, it's more a proof of concept atm.
 
 enum class ModelType {
-    unknown,
-    llama2,
-    llama3,
+  unknown,
+  llama2,
+  llama3,
 };
 
 ModelType get_model_type(Tokenizer* tokenizer) {
@@ -519,19 +510,27 @@ uint64_t get_eot_token(Tokenizer* tokenizer) {
     return tokens[0];
   }
 
-  fprintf(stderr, "No chat template implemnation for model type %d", model_type);
+  fprintf(
+      stderr, "No chat template implemnation for model type %d", model_type);
   exit(EXIT_FAILURE);
 }
 
-std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+std::vector<uint64_t> get_initial_prompt_tokens(
+    const char* cli_system_prompt,
+    const char* cli_user_prompt,
+    Tokenizer* tokenizer) {
   char system_prompt[512];
   char user_prompt[512];
-  char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+  char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
+                                       // characters. We use 200 to be safe.
 
   if (cli_system_prompt != NULL) {
     strcpy(system_prompt, cli_system_prompt);
   } else {
-    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+    read_stdin(
+        "Enter system prompt (optional): ",
+        system_prompt,
+        sizeof(system_prompt));
   }
 
   if (cli_user_prompt != NULL) {
@@ -540,111 +539,114 @@ std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, c
     read_stdin("User: ", user_prompt, sizeof(user_prompt));
   }
 
-    ModelType model_type = get_model_type(tokenizer);
-    std::vector<uint64_t> tokens;
-
-    switch (model_type) {
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
 
+  switch (model_type) {
     case ModelType::llama2:
       if (system_prompt[0] != '\0') {
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
-          system_prompt,
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+            system_prompt,
+            user_prompt);
       } else {
         // const char prompt_template[] = ;
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "[INST] %s [/INST]",
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "[INST] %s [/INST]",
+            user_prompt);
       }
 
-      // We need to add BOS token here and not in template because llama2 tokenizer
-      // does not pattern match special tokens
+      // We need to add BOS token here and not in template because llama2
+      // tokenizer does not pattern match special tokens
       tokens = tokenizer->encode(rendered_prompt, 1, 0);
       break;
 
     case ModelType::llama3:
       if (system_prompt[0] != '\0') {
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-          system_prompt,
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            system_prompt,
+            user_prompt);
       } else {
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            user_prompt);
       }
       tokens = tokenizer->encode(rendered_prompt, 0, 0);
       break;
 
     default:
-      fprintf(stderr, "No chat template implemnation for model type %d", model_type);
+      fprintf(
+          stderr,
+          "No chat template implemnation for model type %d",
+          model_type);
       exit(EXIT_FAILURE);
-    }
+  }
 
-    #ifdef DEBUG
-    std::cerr << "Start of rendered prompt:" << std::endl;
-    std::cerr << rendered_prompt;
-    std::cerr << "End of rendered prompt:" << std::endl;
-    std::cerr << "Encoded prompt: ";
-    for (int i = 0; i < tokens.size(); i++) {
-      std::cerr << tokens[i] << ", ";
-    }
-    std::cerr << std::endl << std::flush;
-    #endif
+#ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+#endif
 
-    return tokens;
+  return tokens;
 }
 
 std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
   char user_prompt[512];
-  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
+                                   // use 150 to be safe.
 
   read_stdin("User: ", user_prompt, sizeof(user_prompt));
 
   ModelType model_type = get_model_type(tokenizer);
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-
     case ModelType::llama2:
       // const char prompt_template[] = ;
-      snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
+      snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt) - 1,
+          "[INST] %s [/INST]",
+          user_prompt);
 
-      // We need to add BOS token here and not in template because llama2 tokenizer
-      // does not pattern match special tokens
-      tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
+      // We need to add BOS token here and not in template because llama2
+      // tokenizer does not pattern match special tokens
+      tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
       break;
 
     case ModelType::llama3:
       snprintf(
-        rendered_prompt,
-        sizeof(rendered_prompt)-1,
-        "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-        user_prompt
-      );
+          rendered_prompt,
+          sizeof(rendered_prompt) - 1,
+          "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          user_prompt);
       tokens = tokenizer->encode(rendered_prompt, 0, 0);
       break;
 
     default:
-      fprintf(stderr, "No chat template implemnation for model type %d", model_type);
+      fprintf(
+          stderr,
+          "No chat template implemnation for model type %d",
+          model_type);
       exit(EXIT_FAILURE);
   }
 
-
-  #ifdef DEBUG
+#ifdef DEBUG
   std::cerr << "Start of rendered prompt:" << std::endl;
   std::cerr << rendered_prompt;
   std::cerr << "End of rendered prompt:" << std::endl;
@@ -653,20 +655,18 @@ std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
     std::cerr << tokens[i] << ", ";
   }
   std::cerr << std::endl << std::flush;
-  #endif
+#endif
 
   return tokens;
 }
 
-
 void chat(
     Transformer* transformer,
     Tokenizer* tokenizer,
     Sampler* sampler,
     const char* cli_user_prompt,
     const char* cli_system_prompt,
     int steps) {
-
   const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
   int num_prompt_tokens = 0;
   std::vector<uint64_t> prompt_tokens;
@@ -679,12 +679,12 @@ void chat(
   int prev_token;
   int pos = 0; // position in the sequence
   while (pos < steps) {
-
     // when it is the user's turn to contribute tokens to the dialog...
     if (user_turn) {
       // get the (optional) system prompt at position 0
       if (pos == 0) {
-        prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
+        prompt_tokens = get_initial_prompt_tokens(
+            cli_system_prompt, cli_user_prompt, tokenizer);
       } else {
         prompt_tokens = get_next_user_prompt_tokens(tokenizer);
       }
@@ -711,12 +711,12 @@ void chat(
 
     // std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;
 
-
     if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
       user_turn = 1;
     }
 
-    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
+    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN &&
+        next != EOT_TOKEN) {
       std::string piece = tokenizer->decode(token, next);
       safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                   // "unsafe" bytes
@@ -727,7 +727,6 @@ void chat(
       printf("\n");
     }
     pos++;
-
   }
   printf("\n");
 }
@@ -752,7 +751,9 @@ void error_usage() {
   fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
   fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
   fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
-  fprintf(stderr, " -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
+  fprintf(
+      stderr,
+      " -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
   exit(EXIT_FAILURE);
 }
 
@@ -776,7 +777,8 @@ int main(int argc, char* argv[]) {
   int llama_ver = 2;
 
 #if defined(ET_USE_ADPATIVE_THREADS)
-  uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
+  uint32_t num_performant_cores =
+      torch::executorch::cpuinfo::get_num_performant_cores();
   if (num_performant_cores > 0) {
     torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
         num_performant_cores);
@@ -820,9 +822,8 @@ int main(int argc, char* argv[]) {
     } else if (argv[i][1] == 'y') {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
-      llama_ver = atoi(argv[i+1]);
-    }
-    else {
+      llama_ver = atoi(argv[i + 1]);
+    } else {
       error_usage();
     }
   }
@@ -837,7 +838,6 @@ int main(int argc, char* argv[]) {
   if (steps < 0)
     steps = 0;
 
-
   if (vocab_size == -1) {
     if (llama_ver == 2) {
       vocab_size = 32000;
@@ -855,16 +855,21 @@ int main(int argc, char* argv[]) {
 
   switch (llama_ver) {
     case 2:
-      tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer =
+          new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
     case 3:
-      tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer =
+          new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
 
     default:
-      fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
+      fprintf(
+          stderr,
+          "Cannot load tokenizer for unrecognized llama version %d",
+          llama_ver);
       exit(EXIT_FAILURE);
   }
 
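Note on the retained path: with a KV cache the runner feeds the model exactly one new token (a {1, 1} tensor) plus its position on each forward call, instead of re-feeding the growing token prefix as the deleted non-__KV_CACHE__ branch did. The standalone C++ sketch below illustrates only that stepping pattern; KVCache, step, and the token ids are hypothetical stand-ins for illustration, not the repo's ExecuTorch types or API.

/* Illustrative only: single-token decode stepping with a KV cache. */
#include <cstdint>
#include <cstdio>
#include <vector>

struct KVCache {
  // Stand-in for the per-position keys/values a real cache would hold.
  std::vector<int64_t> cached_tokens;
};

// One decode step consumes exactly one token and its position,
// mirroring the {1, 1} token tensor kept by this commit.
void step(KVCache& cache, int64_t token, int64_t pos) {
  cache.cached_tokens.push_back(token); // a model would append K/V here
  std::printf("pos=%lld token=%lld cache_len=%zu\n",
              static_cast<long long>(pos),
              static_cast<long long>(token),
              cache.cached_tokens.size());
}

int main() {
  KVCache cache;
  std::vector<int64_t> prompt = {1, 15043, 29892}; // hypothetical token ids
  for (size_t pos = 0; pos < prompt.size(); ++pos) {
    step(cache, prompt[pos], static_cast<int64_t>(pos)); // one token per call
  }
  return 0;
}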