Skip to content

Commit 82d4ff2

Browse files
committed
Use common LLM interface
1 parent 11d1eeb commit 82d4ff2

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

examples/mediatek/executor_runner/mtk_llama_runner.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ Error MTKLlamaRunner::generate(
120120
const std::string& prompt,
121121
int32_t seq_len,
122122
std::function<void(const std::string&)> token_callback,
123-
std::function<void(const Stats&)> stats_callback) {
123+
std::function<void(const Stats&)> stats_callback,
124+
bool,
125+
bool) {
124126
if (!is_loaded()) {
125127
ET_CHECK_OK_OR_RETURN_ERROR(load());
126128
}

examples/mediatek/executor_runner/mtk_llama_runner.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#pragma once
1313

1414
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
15+
#include <executorch/extension/llm/runner/runner_interface.h>
1516
#include <executorch/extension/llm/runner/stats.h>
1617
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
1718
#include <executorch/extension/llm/tokenizer/tiktoken.h>
@@ -31,7 +32,8 @@ using executorch::extension::llm::Tokenizer;
3132
using executorch::runtime::Error;
3233
using executorch::runtime::Result;
3334

34-
class MTKLlamaRunner {
35+
class MTKLlamaRunner
36+
: public executorch::extension::llm::RunnerInterface {
3537
public:
3638
explicit MTKLlamaRunner(
3739
const std::string& model_path,
@@ -44,7 +46,9 @@ class MTKLlamaRunner {
4446
const std::string& prompt,
4547
int32_t seq_len = 128,
4648
std::function<void(const std::string&)> token_callback = {},
47-
std::function<void(const Stats&)> stats_callback = {});
49+
std::function<void(const Stats&)> stats_callback = {},
50+
bool echo = true,
51+
bool warming = false);
4852
void stop();
4953

5054
LlamaModelOptions get_model_options();

extension/android/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
179179
ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED)
180180
SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB})
181181
list(APPEND link_libraries neuron_backend libneuron_buffer_allocator)
182+
target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_MEDIATEK=1)
182183
endif()
183184
endif()
184185

extension/android/jni/jni_layer_llama.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
#include <unordered_map>
1616
#include <vector>
1717

18-
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
1918
#include <executorch/examples/models/llama/runner/runner.h>
2019
#include <executorch/examples/models/llava/runner/llava_runner.h>
2120
#include <executorch/extension/llm/runner/image.h>
21+
#include <executorch/extension/llm/runner/runner_interface.h>
2222
#include <executorch/runtime/platform/log.h>
2323
#include <executorch/runtime/platform/platform.h>
2424
#include <executorch/runtime/platform/runtime.h>
@@ -31,6 +31,10 @@
3131
#include <fbjni/ByteBuffer.h>
3232
#include <fbjni/fbjni.h>
3333

34+
#if defined(EXECUTORCH_BUILD_MEDIATEK)
35+
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
36+
#endif
37+
3438
namespace llm = ::executorch::extension::llm;
3539
using ::executorch::runtime::Error;
3640

@@ -68,9 +72,8 @@ class ExecuTorchLlamaJni
6872
private:
6973
friend HybridBase;
7074
int model_type_category_;
71-
std::unique_ptr<example::Runner> runner_;
75+
std::unique_ptr<llm::RunnerInterface> runner_;
7276
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
73-
std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
7477

7578
public:
7679
constexpr static auto kJavaDescriptor =
@@ -117,11 +120,15 @@ class ExecuTorchLlamaJni
117120
model_path->toStdString().c_str(),
118121
tokenizer_path->toStdString().c_str(),
119122
temperature);
123+
#if defined(EXECUTORCH_BUILD_MEDIATEK)
120124
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
121-
mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
125+
runner_ = std::make_unique<MTKLlamaRunner>(
122126
model_path->toStdString().c_str(),
123127
tokenizer_path->toStdString().c_str(),
124128
temperature);
129+
// Interpret the model type as LLM
130+
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
131+
#endif
125132
}
126133
}
127134

@@ -161,12 +168,6 @@ class ExecuTorchLlamaJni
161168
[callback](std::string result) { callback->onResult(result); },
162169
[callback](const llm::Stats& result) { callback->onStats(result); },
163170
echo);
164-
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
165-
mtk_llama_runner_->generate(
166-
prompt->toStdString(),
167-
seq_len,
168-
[callback](std::string result) { callback->onResult(result); },
169-
[callback](const Stats& result) { callback->onStats(result); });
170171
}
171172
return 0;
172173
}
@@ -256,8 +257,6 @@ class ExecuTorchLlamaJni
256257
multi_modal_runner_->stop();
257258
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
258259
runner_->stop();
259-
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
260-
mtk_llama_runner_->stop();
261260
}
262261
}
263262

@@ -266,8 +265,6 @@ class ExecuTorchLlamaJni
266265
return static_cast<jint>(multi_modal_runner_->load());
267266
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
268267
return static_cast<jint>(runner_->load());
269-
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
270-
return static_cast<jint>(mtk_llama_runner_->load());
271268
}
272269
return static_cast<jint>(Error::InvalidArgument);
273270
}

0 commit comments

Comments (0)