15
15
#include < unordered_map>
16
16
#include < vector>
17
17
18
- #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
19
18
#include < executorch/examples/models/llama/runner/runner.h>
20
19
#include < executorch/examples/models/llava/runner/llava_runner.h>
21
20
#include < executorch/extension/llm/runner/image.h>
21
+ #include < executorch/extension/llm/runner/runner_interface.h>
22
22
#include < executorch/runtime/platform/log.h>
23
23
#include < executorch/runtime/platform/platform.h>
24
24
#include < executorch/runtime/platform/runtime.h>
31
31
#include < fbjni/ByteBuffer.h>
32
32
#include < fbjni/fbjni.h>
33
33
34
+ #if defined(EXECUTORCH_BUILD_MEDIATEK)
35
+ #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
36
+ #endif
37
+
34
38
namespace llm = ::executorch::extension::llm;
35
39
using ::executorch::runtime::Error;
36
40
@@ -68,9 +72,8 @@ class ExecuTorchLlamaJni
68
72
private:
69
73
friend HybridBase;
70
74
int model_type_category_;
71
- std::unique_ptr<example::Runner > runner_;
75
+ std::unique_ptr<llm::RunnerInterface > runner_;
72
76
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
73
- std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
74
77
75
78
public:
76
79
constexpr static auto kJavaDescriptor =
@@ -117,11 +120,15 @@ class ExecuTorchLlamaJni
117
120
model_path->toStdString ().c_str (),
118
121
tokenizer_path->toStdString ().c_str (),
119
122
temperature);
123
+ #if defined(EXECUTORCH_BUILD_MEDIATEK)
120
124
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
121
- mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
125
+ runner_ = std::make_unique<MTKLlamaRunner>(
122
126
model_path->toStdString ().c_str (),
123
127
tokenizer_path->toStdString ().c_str (),
124
128
temperature);
129
+ // Interpret the model type as LLM
130
+ model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
131
+ #endif
125
132
}
126
133
}
127
134
@@ -161,12 +168,6 @@ class ExecuTorchLlamaJni
161
168
[callback](std::string result) { callback->onResult (result); },
162
169
[callback](const llm::Stats& result) { callback->onStats (result); },
163
170
echo);
164
- } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
165
- mtk_llama_runner_->generate (
166
- prompt->toStdString (),
167
- seq_len,
168
- [callback](std::string result) { callback->onResult (result); },
169
- [callback](const Stats& result) { callback->onStats (result); });
170
171
}
171
172
return 0 ;
172
173
}
@@ -256,8 +257,6 @@ class ExecuTorchLlamaJni
256
257
multi_modal_runner_->stop ();
257
258
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
258
259
runner_->stop ();
259
- } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
260
- mtk_llama_runner_->stop ();
261
260
}
262
261
}
263
262
@@ -266,8 +265,6 @@ class ExecuTorchLlamaJni
266
265
return static_cast <jint>(multi_modal_runner_->load ());
267
266
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
268
267
return static_cast <jint>(runner_->load ());
269
- } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
270
- return static_cast <jint>(mtk_llama_runner_->load ());
271
268
}
272
269
return static_cast <jint>(Error::InvalidArgument);
273
270
}
0 commit comments