remove code for no KV Cache path #527

Merged
merged 1 commit into from Apr 27, 2024
191 changes: 98 additions & 93 deletions runner/run.cpp
@@ -1,18 +1,17 @@
/* Inference for Llama-2 Transformer model in pure C++ */
#include <cstdint>
#include <cstdlib>
#include <ctype.h>
#include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <tokenizer.h>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <string>


#ifdef DEBUG
#include <cassert>
#include <iostream>
@@ -167,22 +166,14 @@ float* forward(Transformer* transformer, int token, int pos) {
torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};

torch::Tensor result = transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
torch::Tensor result =
transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
auto logits = result[0].data_ptr();

#else // __ET_MODEL__
ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
#ifndef __KV_CACHE__
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
ManagedTensor tokens_managed(
&(s->toks[pos]),
/*ignored*/ sizeof(int64_t) * (pos + 1),
{1, 1},
ScalarType::Long);
#else // __KV_CACHE__
ManagedTensor tokens_managed(
token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
#endif
std::vector<EValue> inputs;
auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
auto tmp2 = EValue(pos_managed.get_aliasing_tensor());
@@ -491,9 +482,9 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// is not safely implemented, it's more a proof of concept atm.

enum class ModelType {
unknown,
llama2,
llama3,
unknown,
llama2,
llama3,
};

ModelType get_model_type(Tokenizer* tokenizer) {
@@ -519,19 +510,27 @@ uint64_t get_eot_token(Tokenizer* tokenizer) {
return tokens[0];
}

fprintf(stderr, "No chat template implemnation for model type %d", model_type);
fprintf(
stderr, "No chat template implemnation for model type %d", model_type);
exit(EXIT_FAILURE);
}

std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
std::vector<uint64_t> get_initial_prompt_tokens(
const char* cli_system_prompt,
const char* cli_user_prompt,
Tokenizer* tokenizer) {
char system_prompt[512];
char user_prompt[512];
char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
// characters. We use 200 to be safe.

if (cli_system_prompt != NULL) {
strcpy(system_prompt, cli_system_prompt);
} else {
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
read_stdin(
"Enter system prompt (optional): ",
system_prompt,
sizeof(system_prompt));
}

if (cli_user_prompt != NULL) {
@@ -540,111 +539,114 @@ std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, c
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {
ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {
case ModelType::llama2:
if (system_prompt[0] != '\0') {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
system_prompt,
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
system_prompt,
user_prompt);
} else {
// const char prompt_template[] = ;
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"[INST] %s [/INST]",
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"[INST] %s [/INST]",
user_prompt);
}

// We need to add BOS token here and not in template because llama2 tokenizer
// does not pattern match special tokens
// We need to add BOS token here and not in template because llama2
// tokenizer does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, 1, 0);
break;

case ModelType::llama3:
if (system_prompt[0] != '\0') {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
system_prompt,
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
system_prompt,
user_prompt);
} else {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt);
}
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
fprintf(stderr, "No chat template implemnation for model type %d", model_type);
fprintf(
stderr,
"No chat template implemnation for model type %d",
model_type);
exit(EXIT_FAILURE);
}
}

#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
std::cerr << "Encoded prompt: ";
for (int i = 0; i < tokens.size(); i++) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif
#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
std::cerr << "Encoded prompt: ";
for (int i = 0; i < tokens.size(); i++) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif

return tokens;
return tokens;
}

std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
char user_prompt[512];
char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
// use 150 to be safe.

read_stdin("User: ", user_prompt, sizeof(user_prompt));

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {

case ModelType::llama2:
// const char prompt_template[] = ;
snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
snprintf(
rendered_prompt,
sizeof(rendered_prompt) - 1,
"[INST] %s [/INST]",
user_prompt);

// We need to add BOS token here and not in template because llama2 tokenizer
// does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
// We need to add BOS token here and not in template because llama2
// tokenizer does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
break;

case ModelType::llama3:
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt);
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
fprintf(stderr, "No chat template implemnation for model type %d", model_type);
fprintf(
stderr,
"No chat template implemnation for model type %d",
model_type);
exit(EXIT_FAILURE);
}


#ifdef DEBUG
#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
@@ -653,20 +655,18 @@ std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif
#endif

return tokens;
}


void chat(
Transformer* transformer,
Tokenizer* tokenizer,
Sampler* sampler,
const char* cli_user_prompt,
const char* cli_system_prompt,
int steps) {

const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
int num_prompt_tokens = 0;
std::vector<uint64_t> prompt_tokens;
@@ -679,12 +679,12 @@ void chat(
int prev_token;
int pos = 0; // position in the sequence
while (pos < steps) {

// when it is the user's turn to contribute tokens to the dialog...
if (user_turn) {
// get the (optional) system prompt at position 0
if (pos == 0) {
prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
prompt_tokens = get_initial_prompt_tokens(
cli_system_prompt, cli_user_prompt, tokenizer);
} else {
prompt_tokens = get_next_user_prompt_tokens(tokenizer);
}
@@ -711,12 +711,12 @@

// std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;


if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
user_turn = 1;
}

if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
if (user_idx >= num_prompt_tokens && token != EOT_TOKEN &&
next != EOT_TOKEN) {
std::string piece = tokenizer->decode(token, next);
safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
// "unsafe" bytes
@@ -727,7 +727,6 @@
printf("\n");
}
pos++;

}
printf("\n");
}
@@ -752,7 +751,9 @@ void error_usage() {
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
fprintf(stderr, " -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
fprintf(
stderr,
" -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
exit(EXIT_FAILURE);
}

@@ -776,7 +777,8 @@ int main(int argc, char* argv[]) {
int llama_ver = 2;

#if defined(ET_USE_ADPATIVE_THREADS)
uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
uint32_t num_performant_cores =
torch::executorch::cpuinfo::get_num_performant_cores();
if (num_performant_cores > 0) {
torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
num_performant_cores);
@@ -820,9 +822,8 @@ int main(int argc, char* argv[]) {
} else if (argv[i][1] == 'y') {
system_prompt = argv[i + 1];
} else if (argv[i][1] == 'l') {
llama_ver = atoi(argv[i+1]);
}
else {
llama_ver = atoi(argv[i + 1]);
} else {
error_usage();
}
}
@@ -837,7 +838,6 @@ int main(int argc, char* argv[]) {
if (steps < 0)
steps = 0;


if (vocab_size == -1) {
if (llama_ver == 2) {
vocab_size = 32000;
@@ -855,16 +855,21 @@

switch (llama_ver) {
case 2:
tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer =
new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;
case 3:
tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer =
new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;

default:
fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
fprintf(
stderr,
"Cannot load tokenizer for unrecognized llama version %d",
llama_ver);
exit(EXIT_FAILURE);
}

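For readers skimming the diff: the change collapses forward()'s ExecuTorch branch to a single path in which the runner always passes one token id and one position per step and relies on the exported model to maintain its own KV cache. The fragment below is a minimal sketch of that retained input construction, reassembled from the hunks above; it assumes the ManagedTensor, EValue, and ScalarType types that runner/run.cpp already uses from ExecuTorch and is not a standalone program.

    // Sketch of the KV-cache-only input path that remains after this PR.
    // token_buffer and pos_buffer are single-element int64_t buffers holding
    // the current token id and its position, as in forward().
    ManagedTensor tokens_managed(
        token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long); // one token
    ManagedTensor pos_managed(
        pos_buffer, sizeof(int64_t), {1}, ScalarType::Long); // its position

    std::vector<EValue> inputs;
    inputs.push_back(EValue(tokens_managed.get_aliasing_tensor()));
    inputs.push_back(EValue(pos_managed.get_aliasing_tensor()));
    // The module is then run on these inputs and its first output is read
    // back as the float32 logits for this position.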