
Commit 7b1ac5d

shoumikhin authored and facebook-github-bot committed
Factor out model loading and provide a way to stop generation. (#2002)
Summary: Pull Request resolved: #2002

Reviewed By: manuelcandales

Differential Revision: D53848385

fbshipit-source-id: 6bab78da0e8f484966e153007bb4bd86f98b4e76
1 parent 534664d · commit 7b1ac5d

File tree

4 files changed: +65 additions, -35 deletions
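The net effect of the commit is a new public surface on Runner: construction only records the paths, load() performs the actual (idempotent) loading of the method, tokenizer, and sampler, is_loaded() reports whether everything is ready, and stop() ends an in-flight generate() early. A minimal caller sketch against that surface; the .pte and tokenizer paths are placeholders and error handling is abbreviated, so treat this as an illustration rather than a canonical example:

#include <executorch/examples/models/llama2/runner/runner.h>

#include <cstdio>
#include <string>

int main() {
  // Construction is now cheap: it records the paths and creates the
  // Module wrapper, but does not load the method, tokenizer, or sampler.
  torch::executor::Runner runner(
      "llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);

  // Optional eager load; generate() calls load() itself on first use.
  if (runner.load() != torch::executor::Error::Ok) {
    fprintf(stderr, "failed to load model or tokenizer\n");
    return 1;
  }

  // Stream generated pieces through the callback.
  runner.generate("Hello, world!", [](const std::string& piece) {
    printf("%s", piece.c_str());
  });
  return 0;
}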

examples/models/llama2/runner/runner.cpp

Lines changed: 49 additions & 21 deletions

@@ -24,29 +24,42 @@
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
 #include <executorch/runtime/platform/log.h>
 
-namespace torch {
-namespace executor {
+namespace torch::executor {
+namespace {
+static constexpr auto kTopp = 0.9f;
+} // namespace
 
 Runner::Runner(
-    const char* model_path,
-    const char* tokenizer_path,
-    float temperature) {
-  // Constants definition
-  float topp = 0.9f;
-  unsigned long long rng_seed =
-      (unsigned int)time(nullptr); // seed rng with time by default
-  // Create module
-  module_ = std::make_unique<Module>(
-      model_path, Module::MlockConfig::UseMlockIgnoreErrors);
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    const float temperature)
+    : module_(std::make_unique<Module>(
+          model_path,
+          Module::MlockConfig::UseMlockIgnoreErrors)),
+      tokenizer_path_(tokenizer_path),
+      temperature_(temperature) {
+  ET_LOG(
+      Info,
+      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
+      model_path.c_str(),
+      tokenizer_path.c_str());
+}
+
+bool Runner::is_loaded() const {
+  return module_->is_loaded() && tokenizer_ && sampler_;
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
 
   // Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
   // n_EOS max_seq_len from the model
   ET_LOG(Info, "Reading metadata from model");
   const auto method_names = module_->method_names();
-  ET_CHECK_MSG(
-      method_names.ok(),
-      "Failed to read method names from model: %s",
-      model_path);
+  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
   model_methods_ = method_names.get();
   vocab_size_ = getMetadataHelper<int64_t>("get_vocab_size", 32000);
   bos_id_ = getMetadataHelper<int64_t>("get_bos_id", 1);
@@ -59,7 +72,7 @@ Runner::Runner(
 
   // Load tokenizer
   tokenizer_ = std::make_unique<Tokenizer>(vocab_size_, bos_id_, eos_id_);
-  tokenizer_->load(tokenizer_path);
+  tokenizer_->load(tokenizer_path_);
   if (tokenizer_->bos_tok() != bos_id_) {
     ET_LOG(
         Error,
@@ -75,8 +88,13 @@ Runner::Runner(
         eos_id_);
   }
   // Create sampler
-  sampler_ =
-      std::make_unique<Sampler>(vocab_size_, temperature, topp, rng_seed);
+  sampler_ = std::make_unique<Sampler>(
+      vocab_size_,
+      temperature_,
+      kTopp,
+      static_cast<unsigned long long>(std::time(nullptr)));
+
+  return Error::Ok;
 }
 
 template <typename T>
@@ -141,6 +159,9 @@ Error Runner::generate(
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
+  ET_CHECK_OK_OR_RETURN_ERROR(load());
+
+  shouldStop_ = false;
 
   // encode the (string) prompt into tokens sequence
   int num_prompt_tokens = 0;
@@ -292,6 +313,10 @@ Error Runner::generate(
       callback(piece);
     }
 
+    if (shouldStop_) {
+      break;
+    }
+
     // data-dependent terminating condition: we have n_eos_ number of EOS
     if (pos >= num_prompt_tokens && next == eos_id_) {
       eos_counter++;
@@ -338,12 +363,15 @@ Error Runner::generate(
   return Error::Ok;
 }
 
+void Runner::stop() {
+  shouldStop_ = true;
+}
+
 // explicit instantiation of template methods
 template int64_t Runner::getMetadataHelper<int64_t>(
     std::string method_name,
    int64_t default_val);
 template bool Runner::getMetadataHelper<bool>(
     std::string method_name,
     bool default_val);
-} // namespace executor
-} // namespace torch
+} // namespace torch::executor
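A note on the control flow above: generate() resets shouldStop_ on entry and checks the flag once per decoded token, right after the callback fires, so stop() takes effect at the next token boundary rather than mid-token. That makes the token callback itself a convenient place to request a stop, e.g. to enforce an output budget. A sketch; the budget is illustrative and `runner` is assumed to be set up as in the earlier example:

// Stop after 16 streamed pieces. The callback runs on the same thread
// as generate(), so the plain (non-atomic) shouldStop_ flag is fine here.
int remaining = 16;
runner.generate("Once upon a time", [&](const std::string& piece) {
  printf("%s", piece.c_str());
  if (--remaining == 0) {
    runner.stop();  // generate() breaks out at its next per-token check
  }
});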

examples/models/llama2/runner/runner.h

Lines changed: 11 additions & 10 deletions

@@ -21,19 +21,21 @@
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
 
-namespace torch {
-namespace executor {
+namespace torch::executor {
 
 class Runner {
  public:
   explicit Runner(
-      const char* model_path,
-      const char* tokenizer_path,
-      float temperature = 0.8f);
+      const std::string& model_path,
+      const std::string& tokenizer_path,
+      const float temperature = 0.8f);
 
+  bool is_loaded() const;
+  Error load();
   Error generate(
       const std::string& prompt,
       std::function<void(const std::string&)> callback = {});
+  void stop();
 
  private:
   // metadata
@@ -53,13 +55,12 @@ class Runner {
   bool use_kv_cache_;
   bool append_eos_;
   std::unordered_set<std::string> model_methods_;
-  // module
   std::unique_ptr<Module> module_;
-  // tokenizer
+  std::string tokenizer_path_;
+  float temperature_;
   std::unique_ptr<Tokenizer> tokenizer_;
-  // sampler
   std::unique_ptr<Sampler> sampler_;
+  bool shouldStop_{false};
 };
 
-} // namespace executor
-} // namespace torch
+} // namespace torch::executor
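Because stop() only flips a flag, it can also be driven from outside the generation loop, e.g. by a watchdog thread that caps wall-clock time. One caveat worth flagging: shouldStop_ is declared as a plain bool, not std::atomic<bool>, so a cross-thread stop() is not formally synchronized, and a strictly correct caller would want an atomic wrapper. A sketch of the pattern under that caveat, reusing the `runner` from the earlier example:

#include <chrono>
#include <thread>

// Cap generation at roughly 30 seconds of wall-clock time.
std::thread watchdog([&runner] {
  std::this_thread::sleep_for(std::chrono::seconds(30));
  runner.stop();  // note: unsynchronized write to a plain bool
});
runner.generate("Write a short story about a robot.");
watchdog.join();  // this simplified sketch waits out the full sleep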

examples/models/llama2/tokenizer/tokenizer.cpp

Lines changed: 3 additions & 3 deletions

@@ -47,15 +47,15 @@ Tokenizer::Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok)
  * @param tokenizer_path The path to the tokenizer file.
  * @return Error
  */
-Error Tokenizer::load(const char* tokenizer_path) {
+Error Tokenizer::load(const std::string& tokenizer_path) {
   if (initialized_) {
     ET_LOG(Info, "Tokenizer already initialized");
     return Error::Ok;
   }
   // read in the file
-  FILE* file = fopen(tokenizer_path, "rb");
+  FILE* file = fopen(tokenizer_path.c_str(), "rb");
   if (!file) {
-    ET_LOG(Error, "couldn't load %s", tokenizer_path);
+    ET_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
     return Error::InvalidArgument;
   }
   int32_t metadata[2];
examples/models/llama2/tokenizer/tokenizer.h

Lines changed: 2 additions & 1 deletion

@@ -15,6 +15,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <memory>
+#include <string>
 
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -34,7 +35,7 @@ class Tokenizer {
   explicit Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok);
   ~Tokenizer();
 
-  Error load(const char* tokenizer_path);
+  Error load(const std::string& tokenizer_path);
 
   Error encode(
       const char* text,
