refactor runner to fix bugs #2752

Closed · wants to merge 1 commit
231 changes: 137 additions & 94 deletions examples/models/llama2/runner/runner.cpp
@@ -10,6 +10,7 @@
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama2/runner/runner.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/runner_util/managed_tensor.h>

#include <ctime>
@@ -121,24 +122,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
return res;
}
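(The body of getMetadataHelper is collapsed in this view. For orientation, a minimal sketch of what such a helper typically does, assuming the ExecuTorch Module::execute() API and an EValue::to<T>() accessor; the actual body is not part of this diff.)

// Sketch only -- not part of this diff. Reads a scalar from a metadata
// method exported by the model, falling back to a default value.
template <typename T>
T getMetadataHelperSketch(Module* module, const std::string& method_name, T default_val) {
  T res = default_val;
  Result<std::vector<EValue>> out = module->execute(method_name);
  if (out.ok() && out.get().size() == 1) {
    res = out.get()[0].to<T>();
  }
  return res;
}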

std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
// shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
// self.head_dim)
std::vector<std::string> methods = {
"get_n_layers",
"get_max_batch_size",
"get_max_seq_len",
"get_n_kv_heads",
"get_head_dim"};
std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
std::vector<exec_aten::SizesType> result;
for (int i = 0; i < methods.size(); ++i) {
// convert from int64_t to int32_t
result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
}
return result;
}

template <typename T>
int32_t Runner::logitsToToken(
const exec_aten::Tensor& logits_tensor,
@@ -155,6 +138,73 @@ int32_t Runner::logitsToToken(
return sampler_->sample(logits_last);
}

// Given an input token, set up the inputs for the model and execute a single
// step, returning the logits tensor.
Result<torch::executor::Tensor> Runner::run_model_step(
int64_t input_token,
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos,
size_t max_seq_len) {
// ET_LOG(Info, "Input token %" PRIu64, input_token);
if (use_kv_cache_) {
std::vector<EValue> inputs;
auto tokens = managed_tokens.get_aliasing_tensor();
auto start_pos = managed_start_pos.get_aliasing_tensor();

// When using the kv-cache, our input is always a single token, so just
// update it to the latest.
tokens.mutable_data_ptr<int64_t>()[0] = input_token;

// inputs:[tokens, start_pos]
inputs.push_back(tokens);
inputs.push_back(start_pos);

Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Bump start_pos by 1
start_pos.mutable_data_ptr<int64_t>()[0]++;

// Return the logits tensor
return outputs_res.get()[0].toTensor();
} else { // no kv cache
std::vector<EValue> inputs;
auto tokens = managed_tokens.get_aliasing_tensor();
(void)managed_start_pos; // unused

// When not using the kv-cache, our input is the entire history of tokens
// seen so far, so append the new token at the end of the (already resized)
// input. TODO: does this work in ATen mode?
tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;

// inputs:[tokens]
inputs.push_back(tokens);

Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

if (tokens.size(1) < max_seq_len) {
// Resize the tokens tensor to be 1 larger for next step.
managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
}

// Return the logits tensor
return outputs_res.get()[0].toTensor();
}
}
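For orientation, a minimal sketch of how this helper is driven token by token. It mirrors the rewritten generate() loop below; prompt_tokens, num_prompt_tokens, and the two managed tensors are assumed to be set up as shown there.

// Sketch only -- mirrors the loop in generate() below.
int64_t cur_token = prompt_tokens[0];
for (int64_t pos = 0; pos + 1 < seq_len; ++pos) {
  Result<torch::executor::Tensor> logits_res =
      run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
  ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
  cur_token = (pos + 1 < num_prompt_tokens)
      ? prompt_tokens[pos + 1] // still prefilling the prompt
      : logitsToToken<float>(logits_res.get(), pos, 0.0f); // sample
}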

Error Runner::generate(
const std::string& prompt,
int32_t seq_len,
@@ -189,9 +239,6 @@ Error Runner::generate(
prompt_tokens,
&num_prompt_tokens);

for (int i = 0; i < num_prompt_tokens; i++) {
ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
}
ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
num_prompt_tokens < max_seq_len_,
@@ -202,89 +249,94 @@
"Sequence length exceeded - please increase the seq_len value passed to generate()");

// start the main loop
int next; // will store the next token in the sequence
int64_t pos = num_prompt_tokens - 1; // position in the sequence
int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
int logits_index = 0; // index of the logits tensor in the output
std::vector<exec_aten::SizesType> input_shape = {1, 1};
std::vector<exec_aten::SizesType> pos_shape = {1};
int64_t pos = 0; // position in the sequence

std::vector<int64_t> token_data; // allocate space for the tokens
std::vector<int64_t> pos_data; // allocate space for the tokens
std::vector<exec_aten::SizesType> token_shape = {1, seq_len};

std::vector<int64_t> start_pos_data; // allocate space for the start position
std::vector<exec_aten::SizesType> start_pos_shape = {1};

if (use_kv_cache_) {
// set pos to 0, refill token by token
pos = 0;
// hard code these to size 1 as kv cache is locked to static size right now.
token_data.resize(1);
pos_data.resize(seq_len);
token_shape[1] = 1;
start_pos_data.resize(1); // a single start position, zero-initialized
} else {
// reserve data for tokens, notice the size is still 0.
// allocate data for the full sequence of tokens; after this resize the
// vector's size (and capacity) is seq_len.
token_data.resize(seq_len);
}

// initialize tensor wrappers
ManagedTensor pos_managed(
pos_data.data(), pos_data.size(), pos_shape, ScalarType::Long);

// copy prompt tokens into data
for (int i = 0; i <= pos; ++i) {
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
token_data[i] = prompt_tokens[i];
if (i > 0) {
printf(
"%s",
ET_UNWRAP(
tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
ManagedTensor tokens_managed(
token_data.data(),
128, // TODO clean up unused 128 here as ManagedTensor ignores this arg in
// ctor
token_shape,
ScalarType::Long);
// Create with the max shape to appropriately set the capacity of this
// tensor, then resize back to {1, 1} for the first input.
tokens_managed.resize({1, 1});

ManagedTensor start_pos_managed(
start_pos_data.data(), 128, start_pos_shape, ScalarType::Long);

int64_t prev_token;
int64_t cur_token = prompt_tokens[0];

// If we aren't using the kv cache, we can batch-prefill the prompt.
if (!use_kv_cache_) {
tokens_managed.resize({1, num_prompt_tokens});
for (int i = 0; i < num_prompt_tokens - 1; i++) {
tokens_managed.get_aliasing_tensor().mutable_data_ptr<int64_t>()[i] =
prompt_tokens[i];
}
// prefill tokens up to the last prompt token, then enter the loop with
// the last prompt token as the current token.
cur_token = prompt_tokens[num_prompt_tokens - 1];
pos = num_prompt_tokens - 1;

// Print the prompt for consistent output between single token prefill and
// batch prefill.
int prev = prompt_tokens[0];
int cur;
for (int i = 1; i < num_prompt_tokens; i++) {
cur = prompt_tokens[i];
auto piece_res = tokenizer_->decode(prev, cur);
ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
util::safe_printf(piece_res.get());
fflush(stdout);
prev = cur;
}
}

// create a 1xN int tensor with next as value
while (pos + 1 < seq_len) {
// ET_LOG(Info, "Generating step %d...", pos);
// set the current token in the tensor
std::vector<EValue> inputs;
if (use_kv_cache_) {
token_data[0] = token;
input_shape[1] = 1;
// inputs: [tokens, start_pos, k_cache, v_cache]
inputs.emplace_back(pos_managed.get_aliasing_tensor());
} else {
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
token_data[pos] = token;
input_shape[1] = pos + 1;
}
ManagedTensor token_managed(
token_data.data(), token_data.size(), input_shape, ScalarType::Long);
inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
// For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
// Otherwise inputs: [tokens]
Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
ET_CHECK_MSG(
outputs_res.ok(),
"Execution of method forward failed with status 0x%" PRIx32,
static_cast<int32_t>(outputs_res.error()));
// ET_LOG(Info, "Model executed successfully.");
// Generate our tokens
while (pos < seq_len - 1) {
// Run the model
Result<torch::executor::Tensor> logits_res =
run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);

std::vector<EValue> outputs = outputs_res.get();
// Check the outputs.
ET_CHECK_MSG(
outputs.size() == 1 && outputs.at(0).isTensor(),
"Expecting output to have exactly 1 tensor output. Got %zu outputs.",
outputs.size());
if (pos == num_prompt_tokens) {
timers_.first_token_ms = util::time_in_ms();
} else if (pos == num_prompt_tokens - 1) {
timers_.prompt_eval_end_ms = util::time_in_ms();
}
int32_t next_tok;
exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
exec_aten::Tensor& logits_tensor = logits_res.get();

prev_token = cur_token;

long sample_start_time_ms = util::time_in_ms();
switch (logits_tensor.scalar_type()) {
case ScalarType::Float: {
next_tok = logitsToToken<float>(logits_tensor, pos, 0);
cur_token = logitsToToken<float>(logits_tensor, pos, 0);
break;
}
case ScalarType::Half: {
next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
break;
}
default:
@@ -299,19 +351,12 @@ Error Runner::generate(
// advance the state machine
if (pos < num_prompt_tokens - 1) {
// prefill, force the next token to be the next prompt token
next = prompt_tokens[pos + 1];
} else {
// otherwise sample the next token from the logits
next = next_tok;
cur_token = prompt_tokens[pos + 1];
}
// ET_LOG(Info, "Output saved, next = %d", next);
pos++;
if (use_kv_cache_) {
pos_data.at(0) = pos;
}

// print the token as string, decode it with the Tokenizer object
auto piece_res = tokenizer_->decode(token, next);
auto piece_res = tokenizer_->decode(prev_token, cur_token);
ET_CHECK(piece_res.ok());
const char* piece = piece_res.get();

@@ -328,22 +373,20 @@
}

// data-dependent terminating condition: we have n_eos_ number of EOS
if (pos >= num_prompt_tokens && next == eos_id_) {
if (pos >= num_prompt_tokens && cur_token == eos_id_) {
printf("\n");
ET_LOG(Info, "\nReached to the end of generation");
break;
}

token = next;
}
timers_.inference_end_ms = util::time_in_ms();
printf("\n");

if (pos + 1 == seq_len) {
if (pos == seq_len - 1) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
}

timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);

delete[] prompt_tokens;
return Error::Ok;
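For reference, a hypothetical call site for the rewritten generate(). The constructor arguments are assumptions modeled on the llama2 example's main, which is not part of this diff.

// Hypothetical usage sketch -- model/tokenizer paths are placeholders.
Runner runner("llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);
runner.generate("Once upon a time,", /*seq_len=*/128);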
7 changes: 6 additions & 1 deletion examples/models/llama2/runner/runner.h
@@ -20,6 +20,7 @@
#include <executorch/examples/models/llama2/sampler/sampler.h>
#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>

namespace torch::executor {

@@ -45,7 +46,11 @@ class Runner {
template <typename T>
int32_t
logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
std::vector<exec_aten::SizesType> getKVCacheShape();
Result<torch::executor::Tensor> run_model_step(
int64_t input_token,
ManagedTensor& tokens,
ManagedTensor& start_pos,
size_t max_seq_len);
// metadata
int32_t vocab_size_;
int32_t bos_id_;
2 changes: 2 additions & 0 deletions examples/models/llama2/runner/targets.bzl
@@ -35,6 +35,8 @@ def define_common_targets():
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
"//executorch/kernels/quantized:generated_lib" + aten_suffix,
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
] + (_get_operator_lib(aten)) + ([
# Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
# Therefore enable it explicitly for now to avoid failing tests
15 changes: 14 additions & 1 deletion extension/runner_util/managed_tensor.h
@@ -8,6 +8,9 @@

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/assert.h>

#ifdef USE_ATEN_LIB
#include <torch/torch.h>
#else
@@ -56,10 +59,20 @@ class ManagedTensor {
data_ptr_,
dim_order_.data(),
strides_.data(),
TensorShapeDynamism::STATIC);
TensorShapeDynamism::DYNAMIC_BOUND);
#endif
}

void resize(const std::vector<SizesType>& new_sizes) {
ET_CHECK_MSG(
new_sizes.size() == sizes_.size(),
"Cannot change rank of a managed tensor");
auto err = resize_tensor(
this->get_aliasing_tensor(),
exec_aten::ArrayRef<SizesType>(new_sizes.data(), new_sizes.size()));
ET_CHECK(err == Error::Ok);
}

/**
* Get the underlying Tensor object. This is assuming the copying is cheap.
*/
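Together with the switch to TensorShapeDynamism::DYNAMIC_BOUND above, the new resize() lets a caller allocate backing storage for the largest expected shape once, then shrink or regrow within that bound. A usage sketch (illustrative, not part of the diff):

// Sketch: create with the maximum shape, then resize within the bound.
std::vector<int64_t> data(128, 0);
std::vector<exec_aten::SizesType> shape = {1, 128};
ManagedTensor tokens(data.data(), data.size(), shape, ScalarType::Long);

tokens.resize({1, 1});    // OK: same rank, numel within the original bound
tokens.resize({1, 64});   // OK: may grow again up to the original capacity
// tokens.resize({1, 1, 1}); // would abort: rank changes are rejected
// tokens.resize({1, 256});  // would abort: exceeds the bounded capacity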
1 change: 1 addition & 0 deletions extension/runner_util/targets.bzl
@@ -38,5 +38,6 @@ def define_common_targets():
],
deps = [
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
],
)