Factor out model loading and provide a way to stop generation. #2002

Status: Closed
70 changes: 49 additions & 21 deletions examples/models/llama2/runner/runner.cpp
@@ -24,29 +24,42 @@
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
 #include <executorch/runtime/platform/log.h>

-namespace torch {
-namespace executor {
+namespace torch::executor {
+namespace {
+static constexpr auto kTopp = 0.9f;
+} // namespace

 Runner::Runner(
-    const char* model_path,
-    const char* tokenizer_path,
-    float temperature) {
-  // Constants definition
-  float topp = 0.9f;
-  unsigned long long rng_seed =
-      (unsigned int)time(nullptr); // seed rng with time by default
-  // Create module
-  module_ = std::make_unique<Module>(
-      model_path, Module::MlockConfig::UseMlockIgnoreErrors);
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    const float temperature)
+    : module_(std::make_unique<Module>(
+          model_path,
+          Module::MlockConfig::UseMlockIgnoreErrors)),
+      tokenizer_path_(tokenizer_path),
+      temperature_(temperature) {
+  ET_LOG(
+      Info,
+      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
+      model_path.c_str(),
+      tokenizer_path.c_str());
+}
+
+bool Runner::is_loaded() const {
+  return module_->is_loaded() && tokenizer_ && sampler_;
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

   // Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
   // n_EOS max_seq_len from the model
   ET_LOG(Info, "Reading metadata from model");
   const auto method_names = module_->method_names();
-  ET_CHECK_MSG(
-      method_names.ok(),
-      "Failed to read method names from model: %s",
-      model_path);
+  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
   model_methods_ = method_names.get();
   vocab_size_ = getMetadataHelper<int64_t>("get_vocab_size", 32000);
   bos_id_ = getMetadataHelper<int64_t>("get_bos_id", 1);
@@ -59,7 +72,7 @@ Runner::Runner(

   // Load tokenizer
   tokenizer_ = std::make_unique<Tokenizer>(vocab_size_, bos_id_, eos_id_);
-  tokenizer_->load(tokenizer_path);
+  tokenizer_->load(tokenizer_path_);
   if (tokenizer_->bos_tok() != bos_id_) {
     ET_LOG(
         Error,
@@ -75,8 +88,13 @@
         eos_id_);
   }
   // Create sampler
-  sampler_ =
-      std::make_unique<Sampler>(vocab_size_, temperature, topp, rng_seed);
+  sampler_ = std::make_unique<Sampler>(
+      vocab_size_,
+      temperature_,
+      kTopp,
+      static_cast<unsigned long long>(std::time(nullptr)));
+
+  return Error::Ok;
 }

 template <typename T>
@@ -141,6 +159,9 @@ Error Runner::generate(
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
+  ET_CHECK_OK_OR_RETURN_ERROR(load());
+
+  shouldStop_ = false;

   // encode the (string) prompt into tokens sequence
   int num_prompt_tokens = 0;
@@ -292,6 +313,10 @@ Error Runner::generate(
       callback(piece);
     }

+    if (shouldStop_) {
+      break;
+    }
+
     // data-dependent terminating condition: we have n_eos_ number of EOS
     if (pos >= num_prompt_tokens && next == eos_id_) {
       eos_counter++;
@@ -338,12 +363,15 @@ Error Runner::generate(
   return Error::Ok;
 }

+void Runner::stop() {
+  shouldStop_ = true;
+}
+
 // explicit instantiation of template methods
 template int64_t Runner::getMetadataHelper<int64_t>(
     std::string method_name,
     int64_t default_val);
 template bool Runner::getMetadataHelper<bool>(
     std::string method_name,
     bool default_val);
-} // namespace executor
-} // namespace torch
+} // namespace torch::executor
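For orientation, here is a minimal sketch of how a caller might drive the reworked API. The main() harness, the file paths, and the watchdog thread are illustrative assumptions; only the Runner constructor, load(), generate(), and stop() come from this diff.

#include <executorch/examples/models/llama2/runner/runner.h>

#include <chrono>
#include <iostream>
#include <string>
#include <thread>

int main() {
  torch::executor::Runner runner(
      "llama2.pte", "tokenizer.bin", /*temperature=*/0.8f); // assumed paths

  // load() is idempotent and generate() calls it internally, so this
  // warm-up call is optional.
  if (runner.load() != torch::executor::Error::Ok) {
    std::cerr << "failed to load model" << std::endl;
    return 1;
  }

  // Ask the runner to stop after one second; the token loop checks the
  // flag once per generated piece and breaks out.
  std::thread watchdog([&runner] {
    std::this_thread::sleep_for(std::chrono::seconds(1));
    runner.stop();
  });

  const torch::executor::Error status = runner.generate(
      "Once upon a time", [](const std::string& piece) {
        std::cout << piece << std::flush;
      });
  watchdog.join();
  return status == torch::executor::Error::Ok ? 0 : 1;
}

Note that shouldStop_ is a plain bool in this diff; a caller signaling from another thread, as sketched above, would likely want std::atomic<bool> in practice.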
21 changes: 11 additions & 10 deletions examples/models/llama2/runner/runner.h
@@ -21,19 +21,21 @@
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>

-namespace torch {
-namespace executor {
+namespace torch::executor {

 class Runner {
  public:
   explicit Runner(
-      const char* model_path,
-      const char* tokenizer_path,
-      float temperature = 0.8f);
+      const std::string& model_path,
+      const std::string& tokenizer_path,
+      const float temperature = 0.8f);

+  bool is_loaded() const;
+  Error load();
   Error generate(
       const std::string& prompt,
       std::function<void(const std::string&)> callback = {});
+  void stop();

  private:
   // metadata
@@ -53,13 +55,12 @@ class Runner {
   bool use_kv_cache_;
   bool append_eos_;
   std::unordered_set<std::string> model_methods_;
-  // module
   std::unique_ptr<Module> module_;
-  // tokenizer
+  std::string tokenizer_path_;
+  float temperature_;
   std::unique_ptr<Tokenizer> tokenizer_;
-  // sampler
   std::unique_ptr<Sampler> sampler_;
+  bool shouldStop_{false};
 };

-} // namespace executor
-} // namespace torch
+} // namespace torch::executor
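The header now makes the load-once contract explicit: construction just stores the paths and creates the Module wrapper, while the heavy work is deferred. A small sketch of the intended pattern (the warm_up helper is a hypothetical name; is_loaded() covering module, tokenizer, and sampler comes from runner.cpp above):

#include <executorch/examples/models/llama2/runner/runner.h>

using torch::executor::Error;
using torch::executor::Runner;

// Hypothetical helper: construct the runner cheaply up front, then pay the
// load cost once at a moment the app chooses (e.g. behind a splash screen).
Error warm_up(Runner& runner) {
  if (runner.is_loaded()) {
    return Error::Ok; // load() would also return Ok immediately here
  }
  return runner.load(); // loads the forward method, tokenizer, and sampler
}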
6 changes: 3 additions & 3 deletions examples/models/llama2/tokenizer/tokenizer.cpp
@@ -47,15 +47,15 @@ Tokenizer::Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok)
  * @param tokenizer_path The path to the tokenizer file.
  * @return Error
  */
-Error Tokenizer::load(const char* tokenizer_path) {
+Error Tokenizer::load(const std::string& tokenizer_path) {
   if (initialized_) {
     ET_LOG(Info, "Tokenizer already initialized");
     return Error::Ok;
   }
   // read in the file
-  FILE* file = fopen(tokenizer_path, "rb");
+  FILE* file = fopen(tokenizer_path.c_str(), "rb");
   if (!file) {
-    ET_LOG(Error, "couldn't load %s", tokenizer_path);
+    ET_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
     return Error::InvalidArgument;
   }
   int32_t metadata[2];
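Since load() now takes std::string, callers no longer manage a raw C string themselves. A minimal sketch (the load_tokenizer helper and its path handling are assumptions; vocab_size and BOS mirror the defaults in runner.cpp above, and EOS of 2 is the usual llama2 value, assumed here):

#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>

#include <string>

torch::executor::Error load_tokenizer(const std::string& path) {
  torch::executor::Tokenizer tokenizer(
      /*vocab_size=*/32000, /*bos_tok=*/1, /*eos_tok=*/2);
  // load() converts back with c_str() only at the fopen call site.
  return tokenizer.load(path);
}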
3 changes: 2 additions & 1 deletion examples/models/llama2/tokenizer/tokenizer.h
@@ -15,6 +15,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <memory>
+#include <string>

 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -34,7 +35,7 @@
   explicit Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok);
   ~Tokenizer();

-  Error load(const char* tokenizer_path);
+  Error load(const std::string& tokenizer_path);

   Error encode(
       const char* text,