Skip to content

Commit 11d1eeb

Browse files
committed
Add an interface for LLM runner
In case we have custom LLM runners other than llama runner, we want to have a uniform interface
1 parent ef5e841 commit 11d1eeb

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

examples/models/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include <unordered_map>
1919

20+
#include <executorch/extension/llm/runner/runner_interface.h>
2021
#include <executorch/extension/llm/runner/stats.h>
2122
#include <executorch/extension/llm/runner/text_decoder_runner.h>
2223
#include <executorch/extension/llm/runner/text_prefiller.h>
@@ -26,7 +27,8 @@
2627

2728
namespace example {
2829

29-
class ET_EXPERIMENTAL Runner {
30+
class ET_EXPERIMENTAL Runner
31+
: public executorch::extension::llm::RunnerInterface {
3032
public:
3133
explicit Runner(
3234
const std::string& model_path,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// An interface for LLM runners. Developers can create their own runner that
10+
// implements their own load and generation logic to run the model.
11+
12+
#pragma once
13+
14+
#include <functional>
15+
#include <string>
16+
17+
#include <executorch/extension/llm/runner/stats.h>
18+
#include <executorch/extension/module/module.h>
19+
20+
namespace executorch {
namespace extension {
namespace llm {

// Abstract interface for LLM runners. Custom runners (beyond the llama
// runner) implement their own load and generation logic behind this
// uniform API so callers can swap implementations.
//
// NOTE(review): ::executorch::runtime::Error is used directly below but is
// only brought in transitively via the includes above — consider including
// executorch/runtime/core/error.h explicitly (include-what-you-use).
class ET_EXPERIMENTAL RunnerInterface {
 public:
  // Virtual destructor so implementations can be deleted through a
  // RunnerInterface pointer safely.
  virtual ~RunnerInterface() = default;

  // Checks if the model is loaded.
  // @return true once load() has completed successfully, false otherwise.
  virtual bool is_loaded() const = 0;

  // Load the model and tokenizer.
  // @return Error indicating success or the failure encountered while
  //   loading.
  virtual ::executorch::runtime::Error load() = 0;

  // Generate the output tokens.
  // @param prompt The input text the model continues from.
  // @param seq_len Sequence-length budget for generation — presumably the
  //   total of prompt plus generated tokens; TODO confirm against concrete
  //   implementations.
  // @param token_callback Invoked with each piece of generated text as it
  //   becomes available; defaults to empty (no callback).
  // @param stats_callback Invoked with generation statistics (see Stats);
  //   defaults to empty (no callback).
  // @param echo Defaults to true. NOTE(review): name suggests the prompt is
  //   echoed back through token_callback — confirm with implementations.
  // @param warming Defaults to false. NOTE(review): name suggests a warm-up
  //   run — confirm with implementations.
  // @return Error indicating success or the failure encountered during
  //   generation.
  virtual ::executorch::runtime::Error generate(
      const std::string& prompt,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
      bool echo = true,
      bool warming = false) = 0;

  // Stop the generation.
  virtual void stop() = 0;
};

} // namespace llm
} // namespace extension
} // namespace executorch

0 commit comments

Comments (0)