@@ -13,10 +13,10 @@
 #include <unordered_map>
 #include <vector>
 
-#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
 #include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llava/runner/llava_runner.h>
 #include <executorch/extension/llm/runner/image.h>
+#include <executorch/extension/llm/runner/runner_interface.h>
 #include <executorch/runtime/platform/log.h>
 #include <executorch/runtime/platform/platform.h>
 #include <executorch/runtime/platform/runtime.h>
@@ -29,6 +29,10 @@
 #include <fbjni/ByteBuffer.h>
 #include <fbjni/fbjni.h>
 
+#if defined(EXECUTORCH_BUILD_MEDIATEK)
+#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
+#endif
+
 namespace llm = ::executorch::extension::llm;
 using ::executorch::runtime::Error;
 
@@ -112,9 +116,8 @@ class ExecuTorchLlamaJni
  private:
   friend HybridBase;
   int model_type_category_;
-  std::unique_ptr<example::Runner> runner_;
+  std::unique_ptr<llm::RunnerInterface> runner_;
   std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
-  std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
 
  public:
   constexpr static auto kJavaDescriptor =
@@ -161,11 +164,15 @@ class ExecuTorchLlamaJni
           model_path->toStdString().c_str(),
           tokenizer_path->toStdString().c_str(),
           temperature);
+#if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
-      mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
+      runner_ = std::make_unique<MTKLlamaRunner>(
           model_path->toStdString().c_str(),
           tokenizer_path->toStdString().c_str(),
           temperature);
+      // Interpret the model type as LLM
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+#endif
     }
   }
 
@@ -205,12 +212,6 @@ class ExecuTorchLlamaJni
           [callback](std::string result) { callback->onResult(result); },
           [callback](const llm::Stats& result) { callback->onStats(result); },
           echo);
-    } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
-      mtk_llama_runner_->generate(
-          prompt->toStdString(),
-          seq_len,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const Stats& result) { callback->onStats(result); });
     }
     return 0;
   }
@@ -300,8 +301,6 @@ class ExecuTorchLlamaJni
       multi_modal_runner_->stop();
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       runner_->stop();
-    } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
-      mtk_llama_runner_->stop();
     }
   }
 
@@ -310,8 +309,6 @@ class ExecuTorchLlamaJni
       return static_cast<jint>(multi_modal_runner_->load());
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       return static_cast<jint>(runner_->load());
-    } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
-      return static_cast<jint>(mtk_llama_runner_->load());
     }
     return static_cast<jint>(Error::InvalidArgument);
   }
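
The net effect of the diff is that the JNI layer keeps every text-generation backend behind a single llm::RunnerInterface pointer, so the MediaTek runner no longer needs its own member variable or its own dispatch branches in generate/stop/load. Below is a minimal, self-contained sketch of that pattern, not the actual ExecuTorch API: TextRunner, DefaultRunner, VendorRunner, and BUILD_VENDOR_BACKEND are hypothetical stand-ins for llm::RunnerInterface, example::Runner, MTKLlamaRunner, and EXECUTORCH_BUILD_MEDIATEK.

// Sketch only: hypothetical types illustrating the single-pointer dispatch
// that the PR moves the JNI layer to.
#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct TextRunner {  // stands in for llm::RunnerInterface
  virtual ~TextRunner() = default;
  virtual int load() = 0;
  virtual void generate(const std::string& prompt,
                        std::function<void(const std::string&)> on_token) = 0;
  virtual void stop() = 0;
};

struct DefaultRunner : TextRunner {  // stands in for example::Runner
  int load() override { return 0; }
  void generate(const std::string& prompt,
                std::function<void(const std::string&)> on_token) override {
    on_token("[default] " + prompt);
  }
  void stop() override {}
};

#if defined(BUILD_VENDOR_BACKEND)  // stands in for EXECUTORCH_BUILD_MEDIATEK
struct VendorRunner : TextRunner {  // stands in for MTKLlamaRunner
  int load() override { return 0; }
  void generate(const std::string& prompt,
                std::function<void(const std::string&)> on_token) override {
    on_token("[vendor] " + prompt);
  }
  void stop() override {}
};
#endif

int main() {
  // One pointer type covers every backend; only construction is guarded by
  // the build flag, so callers never branch on the backend again.
  std::unique_ptr<TextRunner> runner;
#if defined(BUILD_VENDOR_BACKEND)
  runner = std::make_unique<VendorRunner>();
#else
  runner = std::make_unique<DefaultRunner>();
#endif
  runner->load();
  runner->generate("hello",
                   [](const std::string& t) { std::cout << t << "\n"; });
  runner->stop();
}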