Use a single map for model metadata. #4706


Merged
merged 1 commit into from Aug 13, 2024
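The change collapses the runner's individual metadata members (vocab_size_, bos_id_, n_bos_, use_kv_cache_, and so on) into one std::unordered_map<std::string, int64_t> keyed by the model's metadata method names. Defaults are declared once in the constructor, tokenizer-derived values (BOS/EOS ids, vocab size) are inserted at the start of load(), and a single loop then overwrites any entry whose method the exported model actually provides. The sketch below is a condensed illustration of that pattern, not the actual ExecuTorch code: ReadMethod is a hypothetical stand-in for the Module method lookup, which in the real runner goes through module_->get() and EValue.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for "read the scalar returned by a model metadata
// method, if the model exports one with this name".
using ReadMethod =
    std::function<std::optional<int64_t>(const std::string& method_name)>;

class MetadataSketch {
 public:
  MetadataSketch()
      : metadata_({
            {"append_eos_to_prompt", 0},
            {"enable_dynamic_shape", 0},
            {"get_max_seq_len", 128},
            {"get_n_bos", 1},
            {"get_n_eos", 1},
            {"use_kv_cache", 1},
            {"use_sdpa_with_kv_cache", 0},
        }) {}

  void load(const ReadMethod& read_method) {
    for (auto& [name, value] : metadata_) {
      if (auto provided = read_method(name)) {
        value = *provided; // model-provided value overrides the default
      } else {
        std::printf(
            "Method %s not found, using the default value %lld\n",
            name.c_str(),
            static_cast<long long>(value));
      }
    }
  }

  int64_t at(const std::string& key) const {
    return metadata_.at(key);
  }

 private:
  std::unordered_map<std::string, int64_t> metadata_;
};

int main() {
  MetadataSketch md;
  // Pretend the model only exports get_max_seq_len; every other entry keeps
  // its default.
  md.load([](const std::string& name) -> std::optional<int64_t> {
    if (name == "get_max_seq_len") {
      return 4096;
    }
    return std::nullopt;
  });
  std::printf(
      "max_seq_len = %lld\n", static_cast<long long>(md.at("get_max_seq_len")));
  return 0;
}

Call sites then read values with metadata_.at(kUseKVCache), metadata_.at(kMaxSeqLen), and so on, so supporting a new piece of metadata means adding one map entry instead of a new member variable plus separate load logic.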
137 changes: 77 additions & 60 deletions examples/models/llama2/runner/runner.cpp
@@ -10,29 +10,31 @@
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama2/runner/runner.h>

#include <ctime>

#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/runner_util/managed_tensor.h>

#if ET_USE_TIKTOKEN
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
#else /* BPE */
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#endif /* ET_USE_TIKTOKEN*/
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/llm/runner/metadata_util.h>
#include <executorch/extension/runner_util/managed_tensor.h>

#include <ctime>
#include <memory>
#include <sstream>

#ifdef USE_ATEN_LIB
#include <torch/torch.h>
#endif

#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>

namespace torch::executor {
namespace {
static constexpr auto kAppendEosToPrompt = "append_eos_to_prompt";
static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
static constexpr auto kBosId = "get_bos_id";
static constexpr auto kEosId = "get_eos_id";
static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kNBos = "get_n_bos";
static constexpr auto kNEos = "get_n_eos";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace

Runner::Runner(
const std::string& model_path,
@@ -43,7 +45,23 @@ Runner::Runner(
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
: temperature_(temperature),
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
tokenizer_path_(tokenizer_path) {
tokenizer_path_(tokenizer_path),
tokenizer_(
#if ET_USE_TIKTOKEN
get_tiktoken_for_llama()
#else
std::make_unique<BPETokenizer>()
#endif
),
metadata_({
{kAppendEosToPrompt, false},
{kEnableDynamicShape, false},
{kMaxSeqLen, 128},
{kNBos, 1},
{kNEos, 1},
{kUseKVCache, true},
{kUseSDPAWithKVCache, false},
}) {
ET_LOG(
Info,
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -62,54 +80,49 @@ Error Runner::load() {
}
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

// Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
// n_EOS max_seq_len from the model
ET_LOG(Info, "Reading metadata from model");
const auto method_names = module_->method_names();
ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
model_methods_ = method_names.get();
n_bos_ = get_module_metadata<int64_t>(module_.get(), "get_n_bos", 1);
n_eos_ = get_module_metadata<int64_t>(module_.get(), "get_n_eos", 1);
max_seq_len_ =
get_module_metadata<int64_t>(module_.get(), "get_max_seq_len", 128);
use_kv_cache_ = get_module_metadata(module_.get(), "use_kv_cache", true);
use_sdpa_with_kv_cache_ =
get_module_metadata(module_.get(), "use_sdpa_with_kv_cache", false);
append_eos_ =
get_module_metadata(module_.get(), "append_eos_to_prompt", false);
enable_parallel_prefill_ =
get_module_metadata(module_.get(), "enable_dynamic_shape", false);

// Load tokenizer
#if ET_USE_TIKTOKEN
tokenizer_ = get_tiktoken_for_llama();
#else
tokenizer_ = std::make_unique<BPETokenizer>();
#endif
tokenizer_->load(tokenizer_path_);

vocab_size_ = get_module_metadata<int64_t>(
module_.get(), "get_vocab_size", tokenizer_->vocab_size());
bos_id_ = get_module_metadata<int64_t>(
module_.get(), "get_bos_id", tokenizer_->bos_tok());
eos_id_ = get_module_metadata<int64_t>(
module_.get(), "get_eos_id", tokenizer_->eos_tok());
ET_LOG(Info, "Reading metadata from model");

// Create text decoder runner and prefiller
metadata_[kBosId] = tokenizer_->bos_tok();
metadata_[kEosId] = tokenizer_->eos_tok();
metadata_[kVocabSize] = tokenizer_->vocab_size();

const auto method_names =
ET_UNWRAP(module_->method_names(), "Failed reading method names");

for (auto& pair : metadata_) {
const auto& method_name = pair.first;
auto& value = pair.second;

if (method_names.count(method_name)) {
value = ET_UNWRAP(module_->get(method_name))
.toScalar()
.to<decltype(metadata_)::mapped_type>();
} else {
ET_LOG(
Info,
"Methond %s not found, using the default value %" PRId64,
method_name.c_str(),
value);
}
}
text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
module_.get(), use_kv_cache_, vocab_size_, temperature_);

module_.get(),
metadata_.at(kUseKVCache),
metadata_.at(kVocabSize),
temperature_);
text_prefiller_ = std::make_unique<TextPrefiller>(
tokenizer_.get(),
text_decoder_runner_.get(),
use_kv_cache_,
metadata_.at(kUseKVCache),
enable_parallel_prefill_);

text_token_generator_ = std::make_unique<TextTokenGenerator>(
tokenizer_.get(),
text_decoder_runner_.get(),
use_kv_cache_,
eos_id_,
metadata_.at(kUseKVCache),
metadata_.at(kEosId),
&stats_);

return Error::Ok;
@@ -145,10 +158,14 @@ Error Runner::generate(
shouldStop_ = false;

// Set the sequence length to the max seq length if not provided
seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
? seq_len
: metadata_.at(kMaxSeqLen);

Result<std::vector<uint64_t>> encode_res =
tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
prompt,
metadata_.at(kNBos),
metadata_.at(kAppendEosToPrompt) ? metadata_.at(kNEos) : 0);

ET_CHECK_OK_OR_RETURN_ERROR(
encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
Expand All @@ -159,11 +176,11 @@ Error Runner::generate(

ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
num_prompt_tokens < max_seq_len_,
"num_prompt_tokens %d >= max_seq_len_ %d, Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
num_prompt_tokens < metadata_.at(kMaxSeqLen),
"num_prompt_tokens %d >= max_seq_len_ %" PRId64
", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
num_prompt_tokens,
max_seq_len_);

metadata_.at(kMaxSeqLen));
ET_CHECK_MSG(
num_prompt_tokens < seq_len,
"num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
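One easy-to-miss consequence in the generate() hunk above: values read from the metadata map are int64_t, so the error-message format for the max-sequence-length check switches from %d to the PRId64 macro. A minimal standalone illustration of the formatting issue (plain printf here; the actual code uses ET_CHECK_MSG):

#include <cinttypes> // PRId64
#include <cstdint>
#include <cstdio>

int main() {
  int32_t num_prompt_tokens = 42;
  int64_t max_seq_len = 128; // unordered_map<std::string, int64_t> values are 64-bit

  // %d is only correct for int; for int64_t the portable conversion is the
  // PRId64 macro, which expands to "ld" or "lld" depending on the platform.
  std::printf(
      "num_prompt_tokens %d < max_seq_len %" PRId64 "\n",
      num_prompt_tokens,
      max_seq_len);
  return 0;
}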
20 changes: 3 additions & 17 deletions examples/models/llama2/runner/runner.h
@@ -15,17 +15,14 @@
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>

namespace torch::executor {
using Stats = ::executorch::llm::Stats;
@@ -47,29 +44,18 @@ class Runner {
void stop();

private:
// metadata
int32_t vocab_size_;
int32_t bos_id_;
int32_t eos_id_;
int32_t n_bos_;
int32_t n_eos_;
int32_t max_seq_len_;
bool use_kv_cache_;
bool use_sdpa_with_kv_cache_;
bool append_eos_;
float temperature_;
bool enable_parallel_prefill_;
bool shouldStop_{false};

// model
std::unordered_set<std::string> model_methods_;
std::string model_path_;
std::unique_ptr<Module> module_;
std::string tokenizer_path_;
std::unique_ptr<Tokenizer> tokenizer_;
std::unordered_map<std::string, int64_t> metadata_;
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
std::unique_ptr<TextPrefiller> text_prefiller_;
std::unique_ptr<TextTokenGenerator> text_token_generator_;
std::string tokenizer_path_;
std::unique_ptr<Tokenizer> tokenizer_;

// stats
Stats stats_;
1 change: 0 additions & 1 deletion examples/models/llama2/runner/targets.bzl
@@ -32,7 +32,6 @@ def define_common_targets():
],
exported_deps = [
"//executorch/backends/xnnpack:xnnpack_backend",
"//executorch/extension/llm/runner:metadata_util" + aten_suffix,
"//executorch/extension/llm/runner:stats",
"//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/runner:text_prefiller" + aten_suffix,