Add seq_len to llama runner for early stopping #2051

Closed
wants to merge 1 commit into from
9 changes: 8 additions & 1 deletion examples/models/llama2/main.cpp
@@ -24,6 +24,11 @@ DEFINE_double(
     0.8f,
     "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");

+DEFINE_int32(
+    seq_len,
+    128,
+    "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);

@@ -38,11 +43,13 @@ int32_t main(int32_t argc, char** argv) {

   double temperature = FLAGS_temperature;

+  int32_t seq_len = FLAGS_seq_len;
+
   // create llama runner
   ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);

   // generate
-  runner.generate(prompt);
+  runner.generate(prompt, seq_len);

   return 0;
 }
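With this change the example binary exposes the new option through gflags next to the existing --temperature flag. A hypothetical invocation (the binary name and the --model_path, --tokenizer_path, and --prompt flags are assumptions about the surrounding llama2 example and are not part of this diff):

./llama_main --model_path=llama2.pte --tokenizer_path=tokenizer.bin --prompt="Once upon a time" --temperature=0.8 --seq_len=64

A --seq_len outside (0, max_seq_len] falls back to the model's max_seq_len, so the flag can only shorten generation, never extend it past the exported maximum.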
23 changes: 17 additions & 6 deletions examples/models/llama2/runner/runner.cpp
@@ -119,7 +119,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
   return res;
 }

-std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
+std::vector<exec_aten::SizesType> Runner::getKVCacheShape(int32_t seq_len) {
   // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
   // self.head_dim)
   std::vector<std::string> methods = {
@@ -134,6 +134,9 @@ std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
     // convert from int64_t to int32_t
     result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
   }
+  // update seq_len if one is provided between 1 and max_seq_len
+  ET_CHECK_MSG(result.size() == 5, "KV cache shape must have 5 elements");
+  result[2] = (seq_len > 0 && seq_len <= result[2]) ? seq_len : result[2];
   return result;
 }

@@ -155,6 +158,7 @@ int32_t Runner::logitsToToken(

 Error Runner::generate(
     const std::string& prompt,
+    int32_t seq_len,
     std::function<void(const std::string&)> callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
@@ -168,6 +172,9 @@ Error Runner::generate(
   // max # of prompt tokens: len(prompt) + '\0', ?BOS, ?EOS
   int* prompt_tokens = new int[prompt.size() + 1 + n_bos_ + n_eos_];

+  // Set the sequence length to the max seq length if not provided
+  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
+
   tokenizer_->encode(
       prompt.c_str(),
       n_bos_,
@@ -182,6 +189,10 @@ Error Runner::generate(
       num_prompt_tokens < max_seq_len_,
       "Max seq length exceeded - please increase max seq len value in .../llama2/model.py");

+  ET_CHECK_MSG(
+      num_prompt_tokens < seq_len,
+      "Sequence length exceeded - please increase the seq_len value passed to generate()");
+
   // start the main loop
   long start =
       0; // used to time our code, only initialized after first iteration
@@ -190,7 +201,7 @@ Error Runner::generate(
   int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
   int eos_counter = 0; // counter to capture EOS
   int logits_index = 0; // index of the logits tensor in the output
-  std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape();
+  std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape(seq_len);
   std::vector<exec_aten::SizesType> input_shape = {1, 1};
   std::vector<exec_aten::SizesType> pos_shape = {};
   std::vector<uint8_t> k_data;
@@ -215,7 +226,7 @@ Error Runner::generate(
     token_data.resize(1);
   } else {
     // reserve data for tokens, notice the size is still 0.
-    token_data.resize(max_seq_len_);
+    token_data.resize(seq_len);
   }

   // initialize tensor wrappers
@@ -235,7 +246,7 @@ Error Runner::generate(
     }
   }
   // create a 1xN int tensor with next as value
-  while (pos < max_seq_len_) {
+  while (pos < seq_len) {
     // ET_LOG(Info, "Generating step %d...", pos);
     // set the current token in the tensor
     std::vector<EValue> inputs;
@@ -348,8 +359,8 @@ Error Runner::generate(
   }
   printf("\n");

-  if (pos == max_seq_len_) {
-    ET_LOG(Info, "Maximum sequence length reached!");
+  if (pos == seq_len) {
+    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
   // report achieved tok/s (pos-1 because the timer starts after first
   // iteration)
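The same clamping rule appears twice in runner.cpp, once when sizing the KV cache and once at the top of generate(). A minimal standalone sketch of that rule, with illustrative values that are not taken from any real model:

#include <cstdint>
#include <cstdio>

// Standalone illustration of the clamp used in this PR: a requested seq_len
// is honored only when it lies in (0, max_seq_len]; otherwise the model's
// max_seq_len is used. The values below are illustrative.
static int32_t clamp_seq_len(int32_t requested, int32_t max_seq_len) {
  return (requested > 0 && requested <= max_seq_len) ? requested : max_seq_len;
}

int main() {
  const int32_t max_seq_len = 128;
  printf("%d\n", clamp_seq_len(64, max_seq_len));  // 64  -> early stop
  printf("%d\n", clamp_seq_len(0, max_seq_len));   // 128 -> fall back to max
  printf("%d\n", clamp_seq_len(512, max_seq_len)); // 128 -> truncated to max
  return 0;
}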
3 changes: 2 additions & 1 deletion examples/models/llama2/runner/runner.h
@@ -34,6 +34,7 @@ class Runner {
   Error load();
   Error generate(
       const std::string& prompt,
+      int32_t seq_len = 128,
       std::function<void(const std::string&)> callback = {});
   void stop();

@@ -44,7 +45,7 @@
   template <typename T>
   int32_t
   logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
-  std::vector<exec_aten::SizesType> getKVCacheShape();
+  std::vector<exec_aten::SizesType> getKVCacheShape(int32_t seq_len);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;
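For library callers, a minimal sketch of the updated API. Only the Runner constructor arguments and the generate(prompt, seq_len, callback) signature come from this PR; the include path, file names, prompt, and callback body are illustrative assumptions:

#include <cstdio>
#include <string>

// Assumed include path; adjust to the actual build layout.
#include <executorch/examples/models/llama2/runner/runner.h>

int main() {
  ::torch::executor::Runner runner(
      "llama2.pte", "tokenizer.bin", /*temperature=*/0.8);
  runner.load();
  runner.generate(
      "Once upon a time",
      /*seq_len=*/64, // stop once 64 tokens (prompt + output) have been produced
      [](const std::string& piece) { printf("%s", piece.c_str()); });
  return 0;
}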