Skip to content

Commit e26af08

Browse files
committed
Android JNI llama cache temperature in class
1 parent ef7d4ca commit e26af08

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class ExecuTorchLlmCallbackJni
114114
class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
115115
private:
116116
friend HybridBase;
117+
float temperature_;
117118
int model_type_category_;
118119
std::unique_ptr<llm::IRunner> runner_;
119120
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -149,7 +150,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
149150
facebook::jni::alias_ref<jstring> data_path = nullptr) {
150151
#if defined(ET_USE_THREADPOOL)
151152
// Reserve 1 thread for the main thread.
152-
uint32_t num_performant_cores =
153+
int32_t num_performant_cores =
153154
::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
154155
if (num_performant_cores > 0) {
155156
ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
@@ -169,20 +170,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
169170
runner_ = std::make_unique<example::Runner>(
170171
model_path->toStdString().c_str(),
171172
tokenizer_path->toStdString().c_str(),
172-
temperature,
173173
data_path->toStdString().c_str());
174174
} else {
175175
runner_ = std::make_unique<example::Runner>(
176176
model_path->toStdString().c_str(),
177-
tokenizer_path->toStdString().c_str(),
178-
temperature);
177+
tokenizer_path->toStdString().c_str());
179178
}
180179
#if defined(EXECUTORCH_BUILD_MEDIATEK)
181180
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
182181
runner_ = std::make_unique<MTKLlamaRunner>(
183182
model_path->toStdString().c_str(),
184-
tokenizer_path->toStdString().c_str(),
185-
temperature);
183+
tokenizer_path->toStdString().c_str());
186184
// Interpret the model type as LLM
187185
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
188186
#endif
@@ -222,6 +220,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
222220
executorch::extension::llm::GenerationConfig config{
223221
.echo = static_cast<bool>(echo),
224222
.seq_len = seq_len,
223+
.temperature = temperature_,
225224
};
226225
runner_->generate(
227226
prompt->toStdString(),

0 commit comments

Comments
 (0)