Skip to content

Commit 20cb298

Browse files
larryliu0820 authored and facebook-github-bot committed
Move metadata util to a separate header for reuse (#4550)
Summary: Pull Request resolved: #4550. As titled.
imported-using-ghimport
Test Plan: Imported from OSS
Reviewed By: lucylq
Differential Revision: D60812259
Pulled By: larryliu0820
fbshipit-source-id: 196e3a3ef7660e70e733ecf36584debcc0c897b5
1 parent f52d8ab commit 20cb298

File tree

4 files changed

+65
-44
lines changed

4 files changed

+65
-44
lines changed

examples/models/llama2/runner/runner.cpp

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
1717
#endif /* ET_USE_TIKTOKEN*/
1818
#include <executorch/extension/evalue_util/print_evalue.h>
19+
#include <executorch/extension/module/metadata_util.h>
1920
#include <executorch/extension/runner_util/managed_tensor.h>
2021

2122
#include <ctime>
@@ -66,13 +67,17 @@ Error Runner::load() {
6667
const auto method_names = module_->method_names();
6768
ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
6869
model_methods_ = method_names.get();
69-
n_bos_ = getMetadataHelper<int64_t>("get_n_bos", 1);
70-
n_eos_ = getMetadataHelper<int64_t>("get_n_eos", 1);
71-
max_seq_len_ = getMetadataHelper<int64_t>("get_max_seq_len", 128);
72-
use_kv_cache_ = getMetadataHelper("use_kv_cache", true);
73-
use_sdpa_with_kv_cache_ = getMetadataHelper("use_sdpa_with_kv_cache", false);
74-
append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
75-
enable_parallel_prefill_ = getMetadataHelper("enable_dynamic_shape", false);
70+
n_bos_ = get_module_metadata<int64_t>(module_.get(), "get_n_bos", 1);
71+
n_eos_ = get_module_metadata<int64_t>(module_.get(), "get_n_eos", 1);
72+
max_seq_len_ =
73+
get_module_metadata<int64_t>(module_.get(), "get_max_seq_len", 128);
74+
use_kv_cache_ = get_module_metadata(module_.get(), "use_kv_cache", true);
75+
use_sdpa_with_kv_cache_ =
76+
get_module_metadata(module_.get(), "use_sdpa_with_kv_cache", false);
77+
append_eos_ =
78+
get_module_metadata(module_.get(), "append_eos_to_prompt", false);
79+
enable_parallel_prefill_ =
80+
get_module_metadata(module_.get(), "enable_dynamic_shape", false);
7681

7782
// Load tokenizer
7883
#if ET_USE_TIKTOKEN
@@ -82,10 +87,12 @@ Error Runner::load() {
8287
#endif
8388
tokenizer_->load(tokenizer_path_);
8489

85-
vocab_size_ =
86-
getMetadataHelper<int64_t>("get_vocab_size", tokenizer_->vocab_size());
87-
bos_id_ = getMetadataHelper<int64_t>("get_bos_id", tokenizer_->bos_tok());
88-
eos_id_ = getMetadataHelper<int64_t>("get_eos_id", tokenizer_->eos_tok());
90+
vocab_size_ = get_module_metadata<int64_t>(
91+
module_.get(), "get_vocab_size", tokenizer_->vocab_size());
92+
bos_id_ = get_module_metadata<int64_t>(
93+
module_.get(), "get_bos_id", tokenizer_->bos_tok());
94+
eos_id_ = get_module_metadata<int64_t>(
95+
module_.get(), "get_eos_id", tokenizer_->eos_tok());
8996

9097
// Create sampler
9198
sampler_ = std::make_unique<Sampler>(
@@ -97,28 +104,6 @@ Error Runner::load() {
97104
return Error::Ok;
98105
}
99106

100-
template <typename T>
101-
T Runner::getMetadataHelper(const std::string& method_name, T default_val) {
102-
T res = default_val;
103-
if (model_methods_.count(method_name)) {
104-
Result<std::vector<EValue>> outputs = module_->execute(method_name);
105-
if (outputs.ok()) {
106-
std::vector<EValue> outs = outputs.get();
107-
if (outs.size() > 0) {
108-
res = outs[0].to<T>();
109-
}
110-
}
111-
} else {
112-
ET_LOG(
113-
Info,
114-
"The model does not contain %s method, using default value %lld",
115-
method_name.c_str(),
116-
(long long)default_val);
117-
}
118-
ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
119-
return res;
120-
}
121-
122107
int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
123108
ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
124109
auto num_tokens = logits_tensor.size(1);
@@ -485,12 +470,4 @@ Error Runner::generate(
485470
void Runner::stop() {
486471
shouldStop_ = true;
487472
}
488-
489-
// explicit instantiation of template methods
490-
template int64_t Runner::getMetadataHelper<int64_t>(
491-
const std::string& method_name,
492-
int64_t default_val);
493-
template bool Runner::getMetadataHelper<bool>(
494-
const std::string& method_name,
495-
bool default_val);
496473
} // namespace torch::executor

examples/models/llama2/runner/runner.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ class Runner {
4444
void stop();
4545

4646
private:
47-
// metadata
48-
template <typename T>
49-
T getMetadataHelper(const std::string& method_name, T default_val);
5047
int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
5148
Result<torch::executor::Tensor> prefill(
5249
const std::vector<uint64_t>& tokens,

extension/module/metadata_util.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/**
10+
* Constant metadata can be serialized in .pte files, this helper enables
11+
* easy access to the metadata.
12+
*/
13+
#pragma once
14+
15+
#include <executorch/extension/module/module.h>
16+
17+
namespace torch::executor {
18+
template <typename T>
19+
T get_module_metadata(
20+
Module* module,
21+
const std::string& method_name,
22+
T default_val) {
23+
const auto method_names = module->method_names();
24+
ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
25+
auto model_methods = method_names.get();
26+
27+
T res = default_val;
28+
if (model_methods.count(method_name)) {
29+
Result<std::vector<EValue>> outputs = module->execute(method_name);
30+
if (outputs.ok()) {
31+
std::vector<EValue> outs = outputs.get();
32+
if (outs.size() > 0) {
33+
res = outs[0].to<T>();
34+
}
35+
}
36+
} else {
37+
ET_LOG(
38+
Info,
39+
"The model does not contain %s method, using default value %lld",
40+
method_name.c_str(),
41+
(long long)default_val);
42+
}
43+
ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
44+
return res;
45+
}
46+
} // namespace torch::executor

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def define_common_targets():
1717
],
1818
exported_headers = [
1919
"module.h",
20+
"metadata_util.h",
2021
],
2122
visibility = [
2223
"@EXECUTORCH_CLIENTS",

0 commit comments

Comments
 (0)