Skip to content

Commit cdbcab2

Browse files
kirklandsign and cmodi-meta
authored and committed
Use common LLM interface
1 parent 314c8dd commit cdbcab2

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

examples/mediatek/executor_runner/mtk_llama_runner.cpp

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -120,7 +120,9 @@ Error MTKLlamaRunner::generate(
120120
const std::string& prompt,
121121
int32_t seq_len,
122122
std::function<void(const std::string&)> token_callback,
123-
std::function<void(const Stats&)> stats_callback) {
123+
std::function<void(const Stats&)> stats_callback,
124+
bool,
125+
bool) {
124126
if (!is_loaded()) {
125127
ET_CHECK_OK_OR_RETURN_ERROR(load());
126128
}

examples/mediatek/executor_runner/mtk_llama_runner.h

Lines changed: 6 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -12,6 +12,7 @@
1212
#pragma once
1313

1414
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
15+
#include <executorch/extension/llm/runner/runner_interface.h>
1516
#include <executorch/extension/llm/runner/stats.h>
1617
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
1718
#include <executorch/extension/llm/tokenizer/tiktoken.h>
@@ -31,7 +32,8 @@ using executorch::extension::llm::Tokenizer;
3132
using executorch::runtime::Error;
3233
using executorch::runtime::Result;
3334

34-
class MTKLlamaRunner {
35+
class MTKLlamaRunner
36+
: public executorch::extension::llm::RunnerInterface {
3537
public:
3638
explicit MTKLlamaRunner(
3739
const std::string& model_path,
@@ -44,7 +46,9 @@ class MTKLlamaRunner {
4446
const std::string& prompt,
4547
int32_t seq_len = 128,
4648
std::function<void(const std::string&)> token_callback = {},
47-
std::function<void(const Stats&)> stats_callback = {});
49+
std::function<void(const Stats&)> stats_callback = {},
50+
bool echo = true,
51+
bool warming = false);
4852
void stop();
4953

5054
LlamaModelOptions get_model_options();

extension/android/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -179,6 +179,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
179179
ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED)
180180
SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB})
181181
list(APPEND link_libraries neuron_backend libneuron_buffer_allocator)
182+
target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_MEDIATEK=1)
182183
endif()
183184
endif()
184185

extension/android/jni/jni_layer_llama.cpp

Lines changed: 11 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -13,10 +13,10 @@
1313
#include <unordered_map>
1414
#include <vector>
1515

16-
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
1716
#include <executorch/examples/models/llama/runner/runner.h>
1817
#include <executorch/examples/models/llava/runner/llava_runner.h>
1918
#include <executorch/extension/llm/runner/image.h>
19+
#include <executorch/extension/llm/runner/runner_interface.h>
2020
#include <executorch/runtime/platform/log.h>
2121
#include <executorch/runtime/platform/platform.h>
2222
#include <executorch/runtime/platform/runtime.h>
@@ -29,6 +29,10 @@
2929
#include <fbjni/ByteBuffer.h>
3030
#include <fbjni/fbjni.h>
3131

32+
#if defined(EXECUTORCH_BUILD_MEDIATEK)
33+
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
34+
#endif
35+
3236
namespace llm = ::executorch::extension::llm;
3337
using ::executorch::runtime::Error;
3438

@@ -112,9 +116,8 @@ class ExecuTorchLlamaJni
112116
private:
113117
friend HybridBase;
114118
int model_type_category_;
115-
std::unique_ptr<example::Runner> runner_;
119+
std::unique_ptr<llm::RunnerInterface> runner_;
116120
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
117-
std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
118121

119122
public:
120123
constexpr static auto kJavaDescriptor =
@@ -161,11 +164,15 @@ class ExecuTorchLlamaJni
161164
model_path->toStdString().c_str(),
162165
tokenizer_path->toStdString().c_str(),
163166
temperature);
167+
#if defined(EXECUTORCH_BUILD_MEDIATEK)
164168
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
165-
mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
169+
runner_ = std::make_unique<MTKLlamaRunner>(
166170
model_path->toStdString().c_str(),
167171
tokenizer_path->toStdString().c_str(),
168172
temperature);
173+
// Interpret the model type as LLM
174+
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
175+
#endif
169176
}
170177
}
171178

@@ -205,12 +212,6 @@ class ExecuTorchLlamaJni
205212
[callback](std::string result) { callback->onResult(result); },
206213
[callback](const llm::Stats& result) { callback->onStats(result); },
207214
echo);
208-
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
209-
mtk_llama_runner_->generate(
210-
prompt->toStdString(),
211-
seq_len,
212-
[callback](std::string result) { callback->onResult(result); },
213-
[callback](const Stats& result) { callback->onStats(result); });
214215
}
215216
return 0;
216217
}
@@ -300,8 +301,6 @@ class ExecuTorchLlamaJni
300301
multi_modal_runner_->stop();
301302
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
302303
runner_->stop();
303-
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
304-
mtk_llama_runner_->stop();
305304
}
306305
}
307306

@@ -310,8 +309,6 @@ class ExecuTorchLlamaJni
310309
return static_cast<jint>(multi_modal_runner_->load());
311310
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
312311
return static_cast<jint>(runner_->load());
313-
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
314-
return static_cast<jint>(mtk_llama_runner_->load());
315312
}
316313
return static_cast<jint>(Error::InvalidArgument);
317314
}

0 commit comments

Comments (0)