@@ -10,29 +10,31 @@
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama2/runner/runner.h>
+
+ #include <ctime>
+
+ #include <executorch/extension/llm/runner/util.h>
+ #include <executorch/extension/runner_util/managed_tensor.h>
+
#if ET_USE_TIKTOKEN
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
#else /* BPE */
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#endif /* ET_USE_TIKTOKEN */
- #include <executorch/extension/evalue_util/print_evalue.h>
- #include <executorch/extension/llm/runner/metadata_util.h>
- #include <executorch/extension/runner_util/managed_tensor.h>
-
- #include <ctime>
- #include <memory>
- #include <sstream>
-
- #ifdef USE_ATEN_LIB
- #include <torch/torch.h>
- #endif
-
- #include <executorch/extension/llm/runner/util.h>
- #include <executorch/runtime/core/exec_aten/exec_aten.h>
- #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
- #include <executorch/runtime/platform/log.h>

namespace torch::executor {
+ namespace {
+ static constexpr auto kAppendEosToPrompt = "append_eos_to_prompt";
+ static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
+ static constexpr auto kBosId = "get_bos_id";
+ static constexpr auto kEosId = "get_eos_id";
+ static constexpr auto kMaxSeqLen = "get_max_seq_len";
+ static constexpr auto kNBos = "get_n_bos";
+ static constexpr auto kNEos = "get_n_eos";
+ static constexpr auto kVocabSize = "get_vocab_size";
+ static constexpr auto kUseKVCache = "use_kv_cache";
+ static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+ } // namespace

Runner::Runner(
    const std::string& model_path,
@@ -43,7 +45,23 @@ Runner::Runner(
    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
    : temperature_(temperature),
      module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-       tokenizer_path_(tokenizer_path) {
+       tokenizer_path_(tokenizer_path),
+       tokenizer_(
+ #if ET_USE_TIKTOKEN
+           get_tiktoken_for_llama()
+ #else
+           std::make_unique<BPETokenizer>()
+ #endif
+               ),
+       metadata_({
+           {kAppendEosToPrompt, false},
+           {kEnableDynamicShape, false},
+           {kMaxSeqLen, 128},
+           {kNBos, 1},
+           {kNEos, 1},
+           {kUseKVCache, true},
+           {kUseSDPAWithKVCache, false},
+       }) {
  ET_LOG(
      Info,
      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -62,54 +80,49 @@ Error Runner::load() {
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

-   // Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
-   // n_EOS, max_seq_len from the model
-   ET_LOG(Info, "Reading metadata from model");
-   const auto method_names = module_->method_names();
-   ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
-   model_methods_ = method_names.get();
-   n_bos_ = get_module_metadata<int64_t>(module_.get(), "get_n_bos", 1);
-   n_eos_ = get_module_metadata<int64_t>(module_.get(), "get_n_eos", 1);
-   max_seq_len_ =
-       get_module_metadata<int64_t>(module_.get(), "get_max_seq_len", 128);
-   use_kv_cache_ = get_module_metadata(module_.get(), "use_kv_cache", true);
-   use_sdpa_with_kv_cache_ =
-       get_module_metadata(module_.get(), "use_sdpa_with_kv_cache", false);
-   append_eos_ =
-       get_module_metadata(module_.get(), "append_eos_to_prompt", false);
-   enable_parallel_prefill_ =
-       get_module_metadata(module_.get(), "enable_dynamic_shape", false);
-
-   // Load tokenizer
- #if ET_USE_TIKTOKEN
-   tokenizer_ = get_tiktoken_for_llama();
- #else
-   tokenizer_ = std::make_unique<BPETokenizer>();
- #endif
  tokenizer_->load(tokenizer_path_);

-   vocab_size_ = get_module_metadata<int64_t>(
-       module_.get(), "get_vocab_size", tokenizer_->vocab_size());
-   bos_id_ = get_module_metadata<int64_t>(
-       module_.get(), "get_bos_id", tokenizer_->bos_tok());
-   eos_id_ = get_module_metadata<int64_t>(
-       module_.get(), "get_eos_id", tokenizer_->eos_tok());
+   ET_LOG(Info, "Reading metadata from model");

-   // Create text decoder runner and prefiller
+   metadata_[kBosId] = tokenizer_->bos_tok();
+   metadata_[kEosId] = tokenizer_->eos_tok();
+   metadata_[kVocabSize] = tokenizer_->vocab_size();
+
+   const auto method_names =
+       ET_UNWRAP(module_->method_names(), "Failed reading method names");
+
+   for (auto& pair : metadata_) {
+     const auto& method_name = pair.first;
+     auto& value = pair.second;
+
+     if (method_names.count(method_name)) {
+       value = ET_UNWRAP(module_->get(method_name))
+                   .toScalar()
+                   .to<decltype(metadata_)::mapped_type>();
+     } else {
+       ET_LOG(
+           Info,
+           "Method %s not found, using the default value %" PRId64,
+           method_name.c_str(),
+           value);
+     }
+   }
  text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
-       module_.get(), use_kv_cache_, vocab_size_, temperature_);
-
+       module_.get(),
+       metadata_.at(kUseKVCache),
+       metadata_.at(kVocabSize),
+       temperature_);
  text_prefiller_ = std::make_unique<TextPrefiller>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
-       use_kv_cache_,
+       metadata_.at(kUseKVCache),
      enable_parallel_prefill_);

  text_token_generator_ = std::make_unique<TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
-       use_kv_cache_,
-       eos_id_,
+       metadata_.at(kUseKVCache),
+       metadata_.at(kEosId),
      &stats_);

  return Error::Ok;
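
For readers skimming the hunk above: the core change is replacing one-off get_module_metadata reads with a single defaults map whose entries are overridden by whichever constant methods the exported model actually provides. Below is a minimal, self-contained sketch of that pattern; ModelLike is a hypothetical stand-in for Module, not the real ExecuTorch API.

#include <cinttypes>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical stand-in for Module: exposes the exported method names and
// a way to evaluate a constant method to a scalar.
struct ModelLike {
  std::unordered_set<std::string> methods;
  int64_t call(const std::string& /*name*/) const {
    return 2048; // pretend the exported model returns this constant
  }
};

int main() {
  // Defaults keyed by metadata method name, as in the diff; every value is
  // stored as int64_t, so boolean flags are kept as 0/1.
  std::unordered_map<std::string, int64_t> metadata = {
      {"get_max_seq_len", 128},
      {"use_kv_cache", true},
  };

  ModelLike model{{"get_max_seq_len"}};

  for (auto& pair : metadata) {
    const auto& method_name = pair.first;
    auto& value = pair.second;
    if (model.methods.count(method_name)) {
      value = model.call(method_name); // the model overrides the default
    } else {
      // Mirrors the diff's fallback log: keep the compiled-in default.
      std::printf(
          "Method %s not found, using the default value %" PRId64 "\n",
          method_name.c_str(),
          value);
    }
  }
  std::printf("max_seq_len = %" PRId64 "\n", metadata.at("get_max_seq_len"));
  return 0;
}

One upshot of this design is that adding a new piece of model metadata only requires a new key and default in the map, rather than a new member variable and read call.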
@@ -145,10 +158,14 @@ Error Runner::generate(
  shouldStop_ = false;

  // Set the sequence length to the max seq length if not provided
-   seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
+   seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
+       ? seq_len
+       : metadata_.at(kMaxSeqLen);

-   Result<std::vector<uint64_t>> encode_res =
-       tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
+   Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
+       prompt,
+       metadata_.at(kNBos),
+       metadata_.at(kAppendEosToPrompt) ? metadata_.at(kNEos) : 0);

  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
@@ -159,11 +176,11 @@ Error Runner::generate(

  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
  ET_CHECK_MSG(
-       num_prompt_tokens < max_seq_len_,
-       "num_prompt_tokens %d >= max_seq_len_ %d, Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
+       num_prompt_tokens < metadata_.at(kMaxSeqLen),
+       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
+       ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
      num_prompt_tokens,
-       max_seq_len_);
-
+       metadata_.at(kMaxSeqLen));
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",