Skip to content

Commit 5424791

Browse files
committed
Enable JNI with MTK Llama Runner core functions
1 parent e3706d3 commit 5424791

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
#include <executorch/examples/models/llama2/runner/runner.h>
1919
#include <executorch/examples/models/llava/runner/llava_runner.h>
20+
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
2021
#include <executorch/extension/llm/runner/image.h>
2122
#include <executorch/runtime/platform/log.h>
2223
#include <executorch/runtime/platform/platform.h>
@@ -69,13 +70,15 @@ class ExecuTorchLlamaJni
6970
int model_type_category_;
7071
std::unique_ptr<example::Runner> runner_;
7172
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
73+
std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
7274

7375
public:
7476
constexpr static auto kJavaDescriptor =
7577
"Lorg/pytorch/executorch/LlamaModule;";
7678

7779
constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
7880
constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
81+
constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
7982

8083
static facebook::jni::local_ref<jhybriddata> initHybrid(
8184
facebook::jni::alias_ref<jclass>,
@@ -114,6 +117,11 @@ class ExecuTorchLlamaJni
114117
model_path->toStdString().c_str(),
115118
tokenizer_path->toStdString().c_str(),
116119
temperature);
120+
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
121+
mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
122+
model_path->toStdString().c_str(),
123+
tokenizer_path->toStdString().c_str(),
124+
temperature);
117125
}
118126
}
119127

@@ -153,6 +161,12 @@ class ExecuTorchLlamaJni
153161
[callback](std::string result) { callback->onResult(result); },
154162
[callback](const llm::Stats& result) { callback->onStats(result); },
155163
echo);
164+
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
165+
mtk_llama_runner_->generate(
166+
prompt->toStdString(),
167+
seq_len,
168+
[callback](std::string result) { callback->onResult(result); },
169+
[callback](const Stats& result) { callback->onStats(result); });
156170
}
157171
return 0;
158172
}
@@ -242,6 +256,8 @@ class ExecuTorchLlamaJni
242256
multi_modal_runner_->stop();
243257
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
244258
runner_->stop();
259+
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
260+
mtk_llama_runner_->stop();
245261
}
246262
}
247263

@@ -250,6 +266,8 @@ class ExecuTorchLlamaJni
250266
return static_cast<jint>(multi_modal_runner_->load());
251267
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
252268
return static_cast<jint>(runner_->load());
269+
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
270+
return static_cast<jint>(mtk_llama_runner_->load());
253271
}
254272
return static_cast<jint>(Error::InvalidArgument);
255273
}

0 commit comments

Comments
 (0)