17
17
18
18
#include < executorch/examples/models/llama2/runner/runner.h>
19
19
#include < executorch/examples/models/llava/runner/llava_runner.h>
20
+ #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
20
21
#include < executorch/extension/llm/runner/image.h>
21
22
#include < executorch/runtime/platform/log.h>
22
23
#include < executorch/runtime/platform/platform.h>
@@ -69,13 +70,15 @@ class ExecuTorchLlamaJni
69
70
int model_type_category_;
70
71
std::unique_ptr<example::Runner> runner_;
71
72
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
73
+ std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
72
74
73
75
public:
74
76
constexpr static auto kJavaDescriptor =
75
77
" Lorg/pytorch/executorch/LlamaModule;" ;
76
78
77
79
constexpr static int MODEL_TYPE_CATEGORY_LLM = 1 ;
78
80
constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2 ;
81
+ constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3 ;
79
82
80
83
static facebook::jni::local_ref<jhybriddata> initHybrid (
81
84
facebook::jni::alias_ref<jclass>,
@@ -114,6 +117,11 @@ class ExecuTorchLlamaJni
114
117
model_path->toStdString ().c_str (),
115
118
tokenizer_path->toStdString ().c_str (),
116
119
temperature);
120
+ } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
121
+ mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
122
+ model_path->toStdString ().c_str (),
123
+ tokenizer_path->toStdString ().c_str (),
124
+ temperature);
117
125
}
118
126
}
119
127
@@ -153,6 +161,12 @@ class ExecuTorchLlamaJni
153
161
[callback](std::string result) { callback->onResult (result); },
154
162
[callback](const llm::Stats& result) { callback->onStats (result); },
155
163
echo);
164
+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
165
+ mtk_llama_runner_->generate (
166
+ prompt->toStdString (),
167
+ seq_len,
168
+ [callback](std::string result) { callback->onResult (result); },
169
+ [callback](const Stats& result) { callback->onStats (result); });
156
170
}
157
171
return 0 ;
158
172
}
@@ -242,6 +256,8 @@ class ExecuTorchLlamaJni
242
256
multi_modal_runner_->stop ();
243
257
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
244
258
runner_->stop ();
259
+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
260
+ mtk_llama_runner_->stop ();
245
261
}
246
262
}
247
263
@@ -250,6 +266,8 @@ class ExecuTorchLlamaJni
250
266
return static_cast <jint>(multi_modal_runner_->load ());
251
267
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
252
268
return static_cast <jint>(runner_->load ());
269
+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
270
+ return static_cast <jint>(mtk_llama_runner_->load ());
253
271
}
254
272
return static_cast <jint>(Error::InvalidArgument);
255
273
}
0 commit comments