 #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
 #endif /* ET_USE_TIKTOKEN*/
 #include <executorch/extension/evalue_util/print_evalue.h>
+#include <executorch/extension/module/metadata_util.h>
 #include <executorch/extension/runner_util/managed_tensor.h>
 
 #include <ctime>
@@ -66,13 +67,17 @@ Error Runner::load() {
   const auto method_names = module_->method_names();
   ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
   model_methods_ = method_names.get();
-  n_bos_ = getMetadataHelper<int64_t>("get_n_bos", 1);
-  n_eos_ = getMetadataHelper<int64_t>("get_n_eos", 1);
-  max_seq_len_ = getMetadataHelper<int64_t>("get_max_seq_len", 128);
-  use_kv_cache_ = getMetadataHelper("use_kv_cache", true);
-  use_sdpa_with_kv_cache_ = getMetadataHelper("use_sdpa_with_kv_cache", false);
-  append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
-  enable_parallel_prefill_ = getMetadataHelper("enable_dynamic_shape", false);
+  n_bos_ = get_module_metadata<int64_t>(module_.get(), "get_n_bos", 1);
+  n_eos_ = get_module_metadata<int64_t>(module_.get(), "get_n_eos", 1);
+  max_seq_len_ =
+      get_module_metadata<int64_t>(module_.get(), "get_max_seq_len", 128);
+  use_kv_cache_ = get_module_metadata(module_.get(), "use_kv_cache", true);
+  use_sdpa_with_kv_cache_ =
+      get_module_metadata(module_.get(), "use_sdpa_with_kv_cache", false);
+  append_eos_ =
+      get_module_metadata(module_.get(), "append_eos_to_prompt", false);
+  enable_parallel_prefill_ =
+      get_module_metadata(module_.get(), "enable_dynamic_shape", false);
 
   // Load tokenizer
 #if ET_USE_TIKTOKEN
@@ -82,10 +87,12 @@ Error Runner::load() {
 #endif
   tokenizer_->load(tokenizer_path_);
 
-  vocab_size_ =
-      getMetadataHelper<int64_t>("get_vocab_size", tokenizer_->vocab_size());
-  bos_id_ = getMetadataHelper<int64_t>("get_bos_id", tokenizer_->bos_tok());
-  eos_id_ = getMetadataHelper<int64_t>("get_eos_id", tokenizer_->eos_tok());
+  vocab_size_ = get_module_metadata<int64_t>(
+      module_.get(), "get_vocab_size", tokenizer_->vocab_size());
+  bos_id_ = get_module_metadata<int64_t>(
+      module_.get(), "get_bos_id", tokenizer_->bos_tok());
+  eos_id_ = get_module_metadata<int64_t>(
+      module_.get(), "get_eos_id", tokenizer_->eos_tok());
 
   // Create sampler
   sampler_ = std::make_unique<Sampler>(
@@ -97,28 +104,6 @@ Error Runner::load() {
   return Error::Ok;
 }
 
-template <typename T>
-T Runner::getMetadataHelper(const std::string& method_name, T default_val) {
-  T res = default_val;
-  if (model_methods_.count(method_name)) {
-    Result<std::vector<EValue>> outputs = module_->execute(method_name);
-    if (outputs.ok()) {
-      std::vector<EValue> outs = outputs.get();
-      if (outs.size() > 0) {
-        res = outs[0].to<T>();
-      }
-    }
-  } else {
-    ET_LOG(
-        Info,
-        "The model does not contain %s method, using default value %lld",
-        method_name.c_str(),
-        (long long)default_val);
-  }
-  ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
-  return res;
-}
-
 int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
   ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
   auto num_tokens = logits_tensor.size(1);
@@ -485,12 +470,4 @@ Error Runner::generate(
 void Runner::stop() {
   shouldStop_ = true;
 }
-
-// explicit instantiation of template methods
-template int64_t Runner::getMetadataHelper<int64_t>(
-    const std::string& method_name,
-    int64_t default_val);
-template bool Runner::getMetadataHelper<bool>(
-    const std::string& method_name,
-    bool default_val);
 } // namespace torch::executor
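
Note on the new helper: executorch/extension/module/metadata_util.h itself is not part of this diff, so the sketch below only illustrates the lookup pattern that both the removed Runner::getMetadataHelper() and the new get_module_metadata<T>() calls rely on: if the exported .pte program exposes a metadata method (for example "get_max_seq_len"), execute it and take its first output; otherwise fall back to the caller-supplied default. The free-function name read_metadata_or_default and the module.h include path are assumptions for illustration, not the actual contents of metadata_util.h.

// Sketch only (assumed names/paths), reconstructed from the removed
// Runner::getMetadataHelper() shown in the diff above.
#include <executorch/extension/module/module.h> // assumed header for Module

#include <string>
#include <vector>

namespace example {

template <typename T>
T read_metadata_or_default(
    ::torch::executor::Module& module,
    const std::string& method_name,
    T default_val) {
  // Exported models carry metadata as tiny methods (e.g. "get_n_bos");
  // only execute the method if this .pte file actually contains it.
  auto method_names = module.method_names();
  if (!method_names.ok() || method_names.get().count(method_name) == 0) {
    return default_val; // model exported without this metadata method
  }
  auto outputs = module.execute(method_name);
  if (outputs.ok() && !outputs.get().empty()) {
    return outputs.get()[0].to<T>(); // first output holds the scalar value
  }
  return default_val;
}

} // namespace example

// Usage, mirroring the updated Runner::load():
//   max_seq_len_ = example::read_metadata_or_default<int64_t>(
//       *module_, "get_max_seq_len", 128);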