Skip to content

Commit 4d7b294

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Add an interface for LLM runner (#6356)
Summary: In case we have custom LLM runners other than llama runner, we want to have a uniform interface Pull Request resolved: #6356 Reviewed By: cccclai Differential Revision: D64629696 Pulled By: kirklandsign fbshipit-source-id: b9a670e47c4a73ae1180c85e9f11f0b23e3e4ed6
1 parent 8209bc1 commit 4d7b294

File tree

4 files changed

+63
-1
lines changed

4 files changed

+63
-1
lines changed

examples/models/llama/runner/runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include <unordered_map>
1919

20+
#include <executorch/extension/llm/runner/irunner.h>
2021
#include <executorch/extension/llm/runner/stats.h>
2122
#include <executorch/extension/llm/runner/text_decoder_runner.h>
2223
#include <executorch/extension/llm/runner/text_prefiller.h>
@@ -26,7 +27,7 @@
2627

2728
namespace example {
2829

29-
class ET_EXPERIMENTAL Runner {
30+
class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
3031
public:
3132
explicit Runner(
3233
const std::string& model_path,

examples/models/llama/runner/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def define_common_targets():
3939
],
4040
exported_deps = [
4141
"//executorch/backends/xnnpack:xnnpack_backend",
42+
"//executorch/extension/llm/runner:irunner",
4243
"//executorch/extension/llm/runner:stats",
4344
"//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
4445
"//executorch/extension/llm/runner:text_prefiller" + aten_suffix,

extension/llm/runner/irunner.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// An interface for LLM runners. Developers can create their own runner that
10+
// implements their own load and generation logic to run the model.
11+
12+
#pragma once
13+
14+
#include <functional>
15+
#include <string>
16+
17+
#include <executorch/extension/llm/runner/stats.h>
18+
#include <executorch/extension/module/module.h>
19+
20+
namespace executorch {
21+
namespace extension {
22+
namespace llm {
23+
24+
class ET_EXPERIMENTAL IRunner {
25+
public:
26+
virtual ~IRunner() = default;
27+
28+
// Checks if the model is loaded.
29+
virtual bool is_loaded() const = 0;
30+
31+
// Load the model and tokenizer.
32+
virtual ::executorch::runtime::Error load() = 0;
33+
34+
// Generate the output tokens.
35+
virtual ::executorch::runtime::Error generate(
36+
const std::string& prompt,
37+
int32_t seq_len,
38+
std::function<void(const std::string&)> token_callback = {},
39+
std::function<void(const ::executorch::extension::llm::Stats&)>
40+
stats_callback = {},
41+
bool echo = true,
42+
bool warming = false) = 0;
43+
44+
// Stop the generation.
45+
virtual void stop() = 0;
46+
};
47+
48+
} // namespace llm
49+
} // namespace extension
50+
} // namespace executorch

extension/llm/runner/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
22

33
def define_common_targets():
4+
runtime.cxx_library(
5+
name = "irunner",
6+
exported_headers = [
7+
"irunner.h",
8+
],
9+
visibility = [
10+
"@EXECUTORCH_CLIENTS",
11+
],
12+
)
13+
414
runtime.cxx_library(
515
name = "stats",
616
exported_headers = [

0 commit comments

Comments
 (0)