#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>

-namespace torch {
-namespace executor {
+namespace torch::executor {
+namespace {
+static constexpr auto kTopp = 0.9f;
+} // namespace

Runner::Runner(
-    const char* model_path,
-    const char* tokenizer_path,
-    float temperature) {
-  // Constants definition
-  float topp = 0.9f;
-  unsigned long long rng_seed =
-      (unsigned int)time(nullptr); // seed rng with time by default
-  // Create module
-  module_ = std::make_unique<Module>(
-      model_path, Module::MlockConfig::UseMlockIgnoreErrors);
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    const float temperature)
+    : module_(std::make_unique<Module>(
+          model_path,
+          Module::MlockConfig::UseMlockIgnoreErrors)),
+      tokenizer_path_(tokenizer_path),
+      temperature_(temperature) {
+  ET_LOG(
+      Info,
+      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
+      model_path.c_str(),
+      tokenizer_path.c_str());
+}
+
+bool Runner::is_loaded() const {
+  return module_->is_loaded() && tokenizer_ && sampler_;
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

  // Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
  // n_EOS max_seq_len from the model
  ET_LOG(Info, "Reading metadata from model");
  const auto method_names = module_->method_names();
-  ET_CHECK_MSG(
-      method_names.ok(),
-      "Failed to read method names from model: %s",
-      model_path);
+  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
  model_methods_ = method_names.get();
  vocab_size_ = getMetadataHelper<int64_t>("get_vocab_size", 32000);
  bos_id_ = getMetadataHelper<int64_t>("get_bos_id", 1);
@@ -59,7 +72,7 @@ Runner::Runner(

  // Load tokenizer
  tokenizer_ = std::make_unique<Tokenizer>(vocab_size_, bos_id_, eos_id_);
-  tokenizer_->load(tokenizer_path);
+  tokenizer_->load(tokenizer_path_);
  if (tokenizer_->bos_tok() != bos_id_) {
    ET_LOG(
        Error,
@@ -75,8 +88,13 @@ Runner::Runner(
        eos_id_);
  }
  // Create sampler
-  sampler_ =
-      std::make_unique<Sampler>(vocab_size_, temperature, topp, rng_seed);
+  sampler_ = std::make_unique<Sampler>(
+      vocab_size_,
+      temperature_,
+      kTopp,
+      static_cast<unsigned long long>(std::time(nullptr)));
+
+  return Error::Ok;
}

template <typename T>
@@ -141,6 +159,9 @@ Error Runner::generate(
  // Prepare the inputs.
  // Use ones-initialized inputs.
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
+  ET_CHECK_OK_OR_RETURN_ERROR(load());
+
+  shouldStop_ = false;

  // encode the (string) prompt into tokens sequence
  int num_prompt_tokens = 0;
@@ -292,6 +313,10 @@ Error Runner::generate(
      callback(piece);
    }

+    if (shouldStop_) {
+      break;
+    }
+
    // data-dependent terminating condition: we have n_eos_ number of EOS
    if (pos >= num_prompt_tokens && next == eos_id_) {
      eos_counter++;
@@ -338,12 +363,15 @@ Error Runner::generate(
  return Error::Ok;
}

+void Runner::stop() {
+  shouldStop_ = true;
+}
+
// explicit instantiation of template methods
template int64_t Runner::getMetadataHelper<int64_t>(
    std::string method_name,
    int64_t default_val);
template bool Runner::getMetadataHelper<bool>(
    std::string method_name,
    bool default_val);
-} // namespace executor
-} // namespace torch
+} // namespace torch::executor
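
For context, a minimal sketch of how the reworked API above might be driven. Only Runner, load(), generate()'s per-piece callback, and stop() are taken from the diff; the header path, the exact generate() signature, and the main() scaffolding are assumptions.

#include <cstdio>
#include <string>
#include <thread>

// Assumed header location for the Runner declared in this diff.
#include <executorch/examples/models/llama2/runner/runner.h>

int main() {
  // Constructor arguments match the new signature: paths plus temperature.
  torch::executor::Runner runner("llama2.pte", "tokenizer.bin", 0.8f);

  // load() is now lazy; generate() calls it itself, so this warm-up is
  // optional and mainly surfaces load errors early.
  if (runner.load() != torch::executor::Error::Ok) {
    return 1;
  }

  std::thread gen([&runner] {
    // Assumed generate() shape: a prompt plus a streaming callback that
    // receives each decoded piece as it is produced.
    runner.generate("Once upon a time", [](const std::string& piece) {
      std::fputs(piece.c_str(), stdout);
    });
  });

  // stop() flips shouldStop_, and the token loop breaks on its next check.
  runner.stop();
  gen.join();
  return 0;
}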