[llava][18/N] Move token generation loop to a class #4705


Merged: 7 commits, Aug 13, 2024
88 changes: 19 additions & 69 deletions examples/models/llama2/runner/runner.cpp
@@ -52,7 +52,8 @@ Runner::Runner(
}

bool Runner::is_loaded() const {
return module_->is_loaded() && tokenizer_ && text_decoder_runner_;
return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
text_prefiller_ && text_token_generator_;
}

Error Runner::load() {
@@ -104,6 +105,13 @@ Error Runner::load() {
use_kv_cache_,
enable_parallel_prefill_);

text_token_generator_ = std::make_unique<TextTokenGenerator>(
tokenizer_.get(),
text_decoder_runner_.get(),
use_kv_cache_,
eos_id_,
&stats_);

return Error::Ok;
}

@@ -176,81 +184,19 @@ Error Runner::generate(
wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));

// start the main loop
int64_t pos = num_prompt_tokens; // position in the sequence

// Generate the rest of the sequence
std::vector<uint64_t> token_data; // allocate space for the tokens
std::vector<exec_aten::SizesType> token_shape;

if (use_kv_cache_) {
// hard code these to size 1 as kv cache is locked to static size right now.
token_data = {cur_token};
token_shape = {1, 1};
} else {
token_data = prompt_tokens;
token_data.push_back(cur_token);
token_shape = {1, num_prompt_tokens + 1};
}

// initialize tensor wrappers
ManagedTensor tokens_managed(
token_data.data(), token_shape, ScalarType::Long);

ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);

uint64_t prev_token;

// Generate our tokens
while (pos < seq_len - 1) {
// Run the model
Result<exec_aten::Tensor> logits_res =
text_decoder_runner_->step(tokens_managed, start_pos_managed);
prompt_tokens.push_back(cur_token);
int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
exec_aten::Tensor& logits_tensor = logits_res.get();

prev_token = cur_token;

long sample_start_time_ms = util::time_in_ms();
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
stats_.aggregate_sampling_time_ms +=
util::time_in_ms() - sample_start_time_ms;

pos++;

if (use_kv_cache_) {
// update the token tensor. token_data will not be empty.
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
token_data[0] = cur_token;
} else {
// push it to the back
token_data.push_back(cur_token);
tokens_managed.resize({1, static_cast<int>(token_data.size())});
}

// data-dependent terminating condition: we have n_eos_ number of EOS
if (pos >= num_prompt_tokens && cur_token == eos_id_) {
printf("\n");
ET_LOG(Info, "\nReached to the end of generation");
break;
}

// print the token as string, decode it with the Tokenizer object
wrapped_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));

if (shouldStop_) {
break;
}
}
stats_.inference_end_ms = util::time_in_ms();
printf("\n");

if (pos == seq_len) {
if (num_prompt_tokens + num_generated_tokens == seq_len) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
}

stats_.num_prompt_tokens = num_prompt_tokens;
stats_.num_generated_tokens = pos - num_prompt_tokens;
stats_.num_generated_tokens = num_generated_tokens;
::executorch::llm::print_report(stats_);
if (stats_callback) {
stats_callback(stats_);
@@ -260,6 +206,10 @@ Error Runner::generate(
}

void Runner::stop() {
shouldStop_ = true;
if (is_loaded()) {
text_token_generator_->stop();
} else {
ET_LOG(Error, "Token generator is not loaded, cannot stop");
}
}
} // namespace torch::executor
2 changes: 2 additions & 0 deletions examples/models/llama2/runner/runner.h
@@ -21,6 +21,7 @@
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
@@ -66,6 +67,7 @@ class Runner {
std::unique_ptr<Module> module_;
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
std::unique_ptr<TextPrefiller> text_prefiller_;
std::unique_ptr<TextTokenGenerator> text_token_generator_;
std::string tokenizer_path_;
std::unique_ptr<Tokenizer> tokenizer_;

1 change: 1 addition & 0 deletions examples/models/llama2/runner/targets.bzl
@@ -36,6 +36,7 @@ def define_common_targets():
"//executorch/extension/llm/runner:stats",
"//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
"//executorch/extension/llm/runner:text_token_generator" + aten_suffix,
"//executorch/extension/evalue_util:print_evalue" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
16 changes: 14 additions & 2 deletions extension/llm/runner/stats.h
@@ -8,12 +8,12 @@

// Runner stats for LLM
#pragma once
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/platform/log.h>
#include <cinttypes>
#include <sstream>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <string>

#include <executorch/runtime/platform/log.h>
namespace executorch::llm {

struct Stats {
@@ -40,6 +40,18 @@ struct Stats {
int64_t num_prompt_tokens;
// Token count from generated (total - prompt)
int64_t num_generated_tokens;
inline void on_sampling_begin() {
aggregate_sampling_timer_start_timestamp =
::torch::executor::util::time_in_ms();
}
inline void on_sampling_end() {
aggregate_sampling_time_ms += ::torch::executor::util::time_in_ms() -
aggregate_sampling_timer_start_timestamp;
aggregate_sampling_timer_start_timestamp = 0;
}

private:
long aggregate_sampling_timer_start_timestamp = 0;
};

static constexpr auto kTopp = 0.9f;
14 changes: 14 additions & 0 deletions extension/llm/runner/targets.bzl
@@ -45,6 +45,20 @@ def define_common_targets():
],
)

runtime.cxx_library(
name = "text_token_generator" + aten_suffix,
exported_headers = ["text_token_generator.h"],
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
":text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/tokenizer:tokenizer_header",
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
],
)

runtime.cxx_library(
name = "metadata_util" + aten_suffix,
exported_headers = ["metadata_util.h"],
5 changes: 5 additions & 0 deletions extension/llm/runner/text_decoder_runner.h
@@ -53,6 +53,10 @@ class TextDecoderRunner {
return module_->is_method_loaded(method_name);
}

inline void stop() {
should_stop_ = true;
}

/**
* Sample the next token from the logits tensor.
* @param logits_tensor The logits tensor.
@@ -90,6 +94,7 @@
Module* module_;
std::unique_ptr<Sampler> sampler_;
bool use_kv_cache_;
bool should_stop_{false};
};

} // namespace torch::executor
11 changes: 9 additions & 2 deletions extension/llm/runner/text_prefiller.cpp
@@ -59,9 +59,11 @@ Result<uint64_t> TextPrefiller::prefill(
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
uint64_t cur;
for (int i = 1; i < prompt_tokens.size(); i++) {
for (int i = 0; i < prompt_tokens.size(); i++) {
cur = prompt_tokens[i];
token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
if (cur != tokenizer_->bos_tok()) {
token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
}
prev = cur;
}
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
@@ -82,6 +84,11 @@ Result<uint64_t> TextPrefiller::prefill(
// is bos so don't callback.
exec_aten::Tensor logits_tensor = ET_UNWRAP(
text_decoder_runner_->step(managed_tokens, managed_start_pos));

// if first token is not bos, we need to callback
if (cur_token != tokenizer_->bos_tok()) {
token_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
}
pos = 1; // start from index 1

while (pos < num_prompt_tokens) {
139 changes: 139 additions & 0 deletions extension/llm/runner/text_token_generator.h
@@ -0,0 +1,139 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Generate tokens in a loop.
#pragma once

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>

namespace torch::executor {
using Stats = ::executorch::llm::Stats;

class TextTokenGenerator {
public:
TextTokenGenerator(
Tokenizer* tokenizer,
TextDecoderRunner* text_decoder_runner,
bool use_kv_cache,
uint64_t eos_id,
Stats* stats)
: tokenizer_(tokenizer),
text_decoder_runner_(text_decoder_runner),
eos_id_(eos_id),
use_kv_cache_(use_kv_cache),
stats_(stats) {}

/**
* Token generation loop.
* @param tokens prompt tokens as well as the first token generated by
* prefill.
   * @param start_pos the start position of the new tokens, based on how many
   * prompt tokens were prefilled.
* @param seq_len the total sequence length, including the prompt tokens, next
* token from prefill and new tokens.
* @param token_callback what to do after a token is generated.
* @return how many tokens are generated.
*/
inline Result<int64_t> generate(
std::vector<uint64_t> tokens,
int64_t start_pos,
int32_t seq_len,
std::function<void(const std::string&)> token_callback) {
ET_CHECK_MSG(
!tokens.empty(), "Token generation loop shouldn't take empty tokens");
int64_t pos = start_pos; // position in the sequence

std::vector<uint64_t> token_data; // allocate space for the tokens
std::vector<exec_aten::SizesType> token_shape;

// Token after prefill
uint64_t cur_token = tokens.back();
uint64_t prev_token;

if (use_kv_cache_) {
// hard code these to size 1 as kv cache is locked to static size right
// now.
token_data = {cur_token};
token_shape = {1, 1};
} else {
token_data = tokens;
token_shape = {1, static_cast<int>(tokens.size())};
}

// initialize tensor wrappers
ManagedTensor tokens_managed(
token_data.data(), token_shape, ScalarType::Long);

ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);

// Generate our tokens
while (pos < seq_len - 1) {
// Run the model
Result<exec_aten::Tensor> logits_res =
text_decoder_runner_->step(tokens_managed, start_pos_managed);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
exec_aten::Tensor& logits_tensor = logits_res.get();

prev_token = cur_token;

stats_->on_sampling_begin();
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
stats_->on_sampling_end();

pos++;

if (use_kv_cache_) {
// update the token tensor. token_data will not be empty.
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
token_data[0] = cur_token;
} else {
// push it to the back
token_data.push_back(cur_token);
tokens_managed.resize({1, static_cast<int>(token_data.size())});
}

// print the token as string, decode it with the Tokenizer object
token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));

if (should_stop_) {
break;
}

// data-dependent terminating condition: we have n_eos_ number of EOS
if (cur_token == eos_id_) {
printf("\n");
ET_LOG(Info, "\nReached to the end of generation");
break;
}
}
return pos - start_pos;
}

/**
* Stop the generation loop.
*/
inline void stop() {
should_stop_ = true;
}

private:
Tokenizer* tokenizer_;
TextDecoderRunner* text_decoder_runner_;
uint64_t eos_id_;
bool use_kv_cache_;

// state machine
bool should_stop_ = false;

// stats
Stats* stats_;
};
} // namespace torch::executor
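
To make the new API concrete, below is a minimal usage sketch (not part of this PR) showing how a caller can wire TextTokenGenerator to an already-loaded Tokenizer and TextDecoderRunner, mirroring what Runner::load() and Runner::generate() do above. The run_generation helper, its parameters, and the printf callback are illustrative assumptions; only the TextTokenGenerator constructor, generate() signature, and Stats/print_report usage come from this diff.

#include <cstdio>
#include <functional>
#include <string>
#include <vector>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>

using namespace torch::executor;

// Hypothetical helper, not part of this PR: drives the generation loop after
// prefill has produced the first token. The tokenizer and decoder runner are
// assumed to be created and loaded elsewhere (e.g. in Runner::load()).
Error run_generation(
    Tokenizer* tokenizer,
    TextDecoderRunner* text_decoder_runner,
    std::vector<uint64_t> prompt_tokens,
    uint64_t first_token, // token returned by TextPrefiller::prefill()
    uint64_t eos_id,
    int32_t seq_len,
    bool use_kv_cache) {
  ::executorch::llm::Stats stats;

  TextTokenGenerator text_token_generator(
      tokenizer, text_decoder_runner, use_kv_cache, eos_id, &stats);

  // generate() expects the prompt tokens with the first generated token
  // appended, and start_pos set to the number of prompt tokens.
  int64_t num_prompt_tokens = prompt_tokens.size();
  prompt_tokens.push_back(first_token);

  int64_t num_generated = ET_UNWRAP(text_token_generator.generate(
      prompt_tokens,
      /*start_pos=*/num_prompt_tokens,
      seq_len,
      [](const std::string& piece) { printf("%s", piece.c_str()); }));

  stats.num_prompt_tokens = num_prompt_tokens;
  stats.num_generated_tokens = num_generated;
  ::executorch::llm::print_report(stats);
  return Error::Ok;
}

Calling text_token_generator.stop() from another thread (as Runner::stop() does above) makes the loop exit after the token currently being generated.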