Commit 2378cda

Use a single map for model metadata.

Differential Revision: D61170117
Pull Request resolved: #4706

1 parent: 1e9e5d0 · commit: 2378cda
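The change collapses the runner's per-field metadata members (vocab_size_, bos_id_, eos_id_, n_bos_, n_eos_, max_seq_len_, use_kv_cache_, ...) into one std::unordered_map<std::string, int64_t> keyed by the model's metadata method names: defaults are seeded in the constructor and overridden at load time when the corresponding method exists on the module. Below is a minimal standalone sketch of that map-with-defaults pattern; the model lookup table is a hypothetical stand-in for the real Module API, not ExecuTorch code.

// Standalone sketch (not ExecuTorch API): defaults keyed by metadata method
// name, overridden by whatever the "model" actually exposes.
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Defaults, keyed by the metadata method name the model may export.
  std::unordered_map<std::string, int64_t> metadata = {
      {"get_max_seq_len", 128},
      {"get_n_bos", 1},
      {"get_n_eos", 1},
      {"use_kv_cache", true},
  };

  // Hypothetical stand-in for the loaded module: only some methods exist.
  const std::unordered_map<std::string, std::function<int64_t()>> model = {
      {"get_max_seq_len", [] { return int64_t{2048}; }},
      {"use_kv_cache", [] { return int64_t{1}; }},
  };

  // A single loop replaces one getter call per field: override each default
  // with the model-provided value when the method is present.
  for (auto& [name, value] : metadata) {
    if (auto it = model.find(name); it != model.end()) {
      value = it->second();
    } else {
      std::cout << name << " not found, using default " << value << "\n";
    }
  }

  // Consumers read through one map lookup instead of dedicated members.
  std::cout << "max_seq_len = " << metadata.at("get_max_seq_len") << "\n";
  return 0;
}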

File tree

3 files changed: +80 −78 lines

examples/models/llama2/runner/runner.cpp
examples/models/llama2/runner/runner.h
examples/models/llama2/runner/targets.bzl


examples/models/llama2/runner/runner.cpp

Lines changed: 77 additions & 60 deletions
@@ -10,29 +10,31 @@
 // The module takes in a string as input and emits a string as output.
 
 #include <executorch/examples/models/llama2/runner/runner.h>
+
+#include <ctime>
+
+#include <executorch/extension/llm/runner/util.h>
+#include <executorch/extension/runner_util/managed_tensor.h>
+
 #if ET_USE_TIKTOKEN
 #include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
 #else /* BPE */
 #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
 #endif /* ET_USE_TIKTOKEN*/
-#include <executorch/extension/evalue_util/print_evalue.h>
-#include <executorch/extension/llm/runner/metadata_util.h>
-#include <executorch/extension/runner_util/managed_tensor.h>
-
-#include <ctime>
-#include <memory>
-#include <sstream>
-
-#ifdef USE_ATEN_LIB
-#include <torch/torch.h>
-#endif
-
-#include <executorch/extension/llm/runner/util.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
-#include <executorch/runtime/platform/log.h>
 
 namespace torch::executor {
+namespace {
+static constexpr auto kAppendEosToPrompt = "append_eos_to_prompt";
+static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
+static constexpr auto kBosId = "get_bos_id";
+static constexpr auto kEosId = "get_eos_id";
+static constexpr auto kMaxSeqLen = "get_max_seq_len";
+static constexpr auto kNBos = "get_n_bos";
+static constexpr auto kNEos = "get_n_eos";
+static constexpr auto kVocabSize = "get_vocab_size";
+static constexpr auto kUseKVCache = "use_kv_cache";
+static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+} // namespace
 
 Runner::Runner(
     const std::string& model_path,
@@ -43,7 +45,23 @@ Runner::Runner(
     // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
     : temperature_(temperature),
       module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-      tokenizer_path_(tokenizer_path) {
+      tokenizer_path_(tokenizer_path),
+      tokenizer_(
+#if ET_USE_TIKTOKEN
+          get_tiktoken_for_llama()
+#else
+          std::make_unique<BPETokenizer>()
+#endif
+          ),
+      metadata_({
+          {kAppendEosToPrompt, false},
+          {kEnableDynamicShape, false},
+          {kMaxSeqLen, 128},
+          {kNBos, 1},
+          {kNEos, 1},
+          {kUseKVCache, true},
+          {kUseSDPAWithKVCache, false},
+      }) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -62,54 +80,49 @@ Error Runner::load() {
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
 
-  // Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
-  // n_EOS max_seq_len from the model
-  ET_LOG(Info, "Reading metadata from model");
-  const auto method_names = module_->method_names();
-  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
-  model_methods_ = method_names.get();
-  n_bos_ = get_module_metadata<int64_t>(module_.get(), "get_n_bos", 1);
-  n_eos_ = get_module_metadata<int64_t>(module_.get(), "get_n_eos", 1);
-  max_seq_len_ =
-      get_module_metadata<int64_t>(module_.get(), "get_max_seq_len", 128);
-  use_kv_cache_ = get_module_metadata(module_.get(), "use_kv_cache", true);
-  use_sdpa_with_kv_cache_ =
-      get_module_metadata(module_.get(), "use_sdpa_with_kv_cache", false);
-  append_eos_ =
-      get_module_metadata(module_.get(), "append_eos_to_prompt", false);
-  enable_parallel_prefill_ =
-      get_module_metadata(module_.get(), "enable_dynamic_shape", false);
-
-  // Load tokenizer
-#if ET_USE_TIKTOKEN
-  tokenizer_ = get_tiktoken_for_llama();
-#else
-  tokenizer_ = std::make_unique<BPETokenizer>();
-#endif
   tokenizer_->load(tokenizer_path_);
 
-  vocab_size_ = get_module_metadata<int64_t>(
-      module_.get(), "get_vocab_size", tokenizer_->vocab_size());
-  bos_id_ = get_module_metadata<int64_t>(
-      module_.get(), "get_bos_id", tokenizer_->bos_tok());
-  eos_id_ = get_module_metadata<int64_t>(
-      module_.get(), "get_eos_id", tokenizer_->eos_tok());
+  ET_LOG(Info, "Reading metadata from model");
 
-  // Create text decoder runner and prefiller
+  metadata_[kBosId] = tokenizer_->bos_tok();
+  metadata_[kEosId] = tokenizer_->eos_tok();
+  metadata_[kVocabSize] = tokenizer_->vocab_size();
+
+  const auto method_names =
+      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+
+  for (auto& pair : metadata_) {
+    const auto& method_name = pair.first;
+    auto& value = pair.second;
+
+    if (method_names.count(method_name)) {
+      value = ET_UNWRAP(module_->get(method_name))
+                  .toScalar()
+                  .to<decltype(metadata_)::mapped_type>();
+    } else {
+      ET_LOG(
+          Info,
+          "Methond %s not found, using the default value %" PRId64,
+          method_name.c_str(),
+          value);
+    }
+  }
   text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
-      module_.get(), use_kv_cache_, vocab_size_, temperature_);
-
+      module_.get(),
+      metadata_.at(kUseKVCache),
+      metadata_.at(kVocabSize),
+      temperature_);
   text_prefiller_ = std::make_unique<TextPrefiller>(
       tokenizer_.get(),
       text_decoder_runner_.get(),
-      use_kv_cache_,
+      metadata_.at(kUseKVCache),
      enable_parallel_prefill_);
 
   text_token_generator_ = std::make_unique<TextTokenGenerator>(
       tokenizer_.get(),
       text_decoder_runner_.get(),
-      use_kv_cache_,
-      eos_id_,
+      metadata_.at(kUseKVCache),
+      metadata_.at(kEosId),
       &stats_);
 
   return Error::Ok;
@@ -145,10 +158,14 @@ Error Runner::generate(
   shouldStop_ = false;
 
   // Set the sequence length to the max seq length if not provided
-  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
+  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
+      ? seq_len
+      : metadata_.at(kMaxSeqLen);
 
-  Result<std::vector<uint64_t>> encode_res =
-      tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
+  Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
+      prompt,
+      metadata_.at(kNBos),
+      metadata_.at(kAppendEosToPrompt) ? metadata_.at(kNEos) : 0);
 
   ET_CHECK_OK_OR_RETURN_ERROR(
       encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
@@ -159,11 +176,11 @@ Error Runner::generate(
 
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < max_seq_len_,
-      "num_prompt_tokens %d >= max_seq_len_ %d, Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
+      num_prompt_tokens < metadata_.at(kMaxSeqLen),
+      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
+      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
       num_prompt_tokens,
-      max_seq_len_);
-
+      metadata_.at(kMaxSeqLen));
   ET_CHECK_MSG(
       num_prompt_tokens < seq_len,
       "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",

examples/models/llama2/runner/runner.h

Lines changed: 3 additions & 17 deletions
@@ -15,17 +15,14 @@
 #include <functional>
 #include <memory>
 #include <string>
-#include <type_traits>
 #include <unordered_map>
 
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
 #include <executorch/extension/llm/runner/text_token_generator.h>
-#include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/llm/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
-#include <executorch/extension/runner_util/managed_tensor.h>
 
 namespace torch::executor {
 using Stats = ::executorch::llm::Stats;
@@ -47,29 +44,18 @@ class Runner {
   void stop();
 
  private:
-  // metadata
-  int32_t vocab_size_;
-  int32_t bos_id_;
-  int32_t eos_id_;
-  int32_t n_bos_;
-  int32_t n_eos_;
-  int32_t max_seq_len_;
-  bool use_kv_cache_;
-  bool use_sdpa_with_kv_cache_;
-  bool append_eos_;
   float temperature_;
   bool enable_parallel_prefill_;
   bool shouldStop_{false};
 
   // model
-  std::unordered_set<std::string> model_methods_;
-  std::string model_path_;
   std::unique_ptr<Module> module_;
+  std::string tokenizer_path_;
+  std::unique_ptr<Tokenizer> tokenizer_;
+  std::unordered_map<std::string, int64_t> metadata_;
   std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
   std::unique_ptr<TextPrefiller> text_prefiller_;
   std::unique_ptr<TextTokenGenerator> text_token_generator_;
-  std::string tokenizer_path_;
-  std::unique_ptr<Tokenizer> tokenizer_;
 
   // stats
   Stats stats_;
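A side note on the reordered header members: tokenizer_path_, tokenizer_, and metadata_ are now set in the constructor's member-initializer list, and C++ initializes non-static data members in declaration order rather than initializer-list order, so keeping the declarations in the same order as the initializer list avoids -Wreorder warnings and any accidental use of a not-yet-initialized member. A tiny illustrative sketch with hypothetical names:

// Illustrative only (hypothetical names): construction follows declaration
// order, so tokenizer_path_ is initialized before metadata_ no matter how
// the initializer list is written.
#include <cstdint>
#include <string>
#include <unordered_map>

struct RunnerSketch {
  explicit RunnerSketch(const std::string& path)
      : tokenizer_path_(path),                    // runs first (declared first)
        metadata_({{"get_max_seq_len", 128}}) {}  // runs second (declared second)

  std::string tokenizer_path_;
  std::unordered_map<std::string, int64_t> metadata_;
};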

examples/models/llama2/runner/targets.bzl

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ def define_common_targets():
         ],
         exported_deps = [
             "//executorch/backends/xnnpack:xnnpack_backend",
-            "//executorch/extension/llm/runner:metadata_util" + aten_suffix,
             "//executorch/extension/llm/runner:stats",
             "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
             "//executorch/extension/llm/runner:text_prefiller" + aten_suffix,

0 commit comments