Skip to content

Commit 6837d46

Browse files
ggerganovarthw
authored andcommitted
llama.android : fix build (ggml-org#9350)
1 parent 1110222 commit 6837d46

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
269269
return env->NewStringUTF(result.str().c_str());
270270
}
271271

272-
extern "C"
273-
JNIEXPORT void JNICALL
274-
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
275-
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
276-
}
277-
278272
extern "C"
279273
JNIEXPORT jlong JNICALL
280274
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
@@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
311305
return reinterpret_cast<jlong>(batch);
312306
}
313307

308+
extern "C"
309+
JNIEXPORT void JNICALL
310+
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
311+
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
312+
}
313+
314+
extern "C"
315+
JNIEXPORT jlong JNICALL
316+
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
317+
auto sparams = llama_sampler_chain_default_params();
318+
sparams.no_perf = true;
319+
llama_sampler * smpl = llama_sampler_chain_init(sparams);
320+
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
321+
322+
return reinterpret_cast<jlong>(smpl);
323+
}
324+
325+
extern "C"
326+
JNIEXPORT void JNICALL
327+
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
328+
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
329+
}
330+
314331
extern "C"
315332
JNIEXPORT void JNICALL
316333
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
@@ -380,24 +397,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
380397
JNIEnv * env,
381398
jobject,
382399
jlong context_pointer,
383-
jlong sampling_pointer,
384400
jlong batch_pointer,
401+
jlong sampler_pointer,
385402
jint n_len,
386403
jobject intvar_ncur
387404
) {
388405
const auto context = reinterpret_cast<llama_context *>(context_pointer);
389-
const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
390-
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
406+
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
407+
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
391408
const auto model = llama_get_model(context);
392409

393410
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
394411
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
395412
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
396413

397414
// sample the most likely token
398-
const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
415+
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
399416

400-
llama_sampler_accept(sampling, new_token_id);
417+
llama_sampler_accept(sampler, new_token_id);
401418

402419
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
403420
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ class LLamaAndroid {
4545
private external fun free_context(context: Long)
4646
private external fun backend_init(numa: Boolean)
4747
private external fun backend_free()
48-
private external fun free_batch(batch: Long)
4948
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
49+
private external fun free_batch(batch: Long)
50+
private external fun new_sampler(): Long
51+
private external fun free_sampler(sampler: Long)
5052
private external fun bench_model(
5153
context: Long,
5254
model: Long,
@@ -69,6 +71,7 @@ class LLamaAndroid {
6971
private external fun completion_loop(
7072
context: Long,
7173
batch: Long,
74+
sampler: Long,
7275
nLen: Int,
7376
ncur: IntVar
7477
): String?
@@ -101,8 +104,11 @@ class LLamaAndroid {
101104
val batch = new_batch(512, 0, 1)
102105
if (batch == 0L) throw IllegalStateException("new_batch() failed")
103106

107+
val sampler = new_sampler()
108+
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
109+
104110
Log.i(tag, "Loaded model $pathToModel")
105-
threadLocalState.set(State.Loaded(model, context, batch))
111+
threadLocalState.set(State.Loaded(model, context, batch, sampler))
106112
}
107113
else -> throw IllegalStateException("Model already loaded")
108114
}
@@ -114,7 +120,7 @@ class LLamaAndroid {
114120
is State.Loaded -> {
115121
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
116122
while (ncur.value <= nlen) {
117-
val str = completion_loop(state.context, state.batch, nlen, ncur)
123+
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
118124
if (str == null) {
119125
break
120126
}
@@ -138,6 +144,7 @@ class LLamaAndroid {
138144
free_context(state.context)
139145
free_model(state.model)
140146
free_batch(state.batch)
147+
free_sampler(state.sampler);
141148

142149
threadLocalState.set(State.Idle)
143150
}
@@ -161,7 +168,7 @@ class LLamaAndroid {
161168

162169
private sealed interface State {
163170
data object Idle: State
164-
data class Loaded(val model: Long, val context: Long, val batch: Long): State
171+
data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
165172
}
166173

167174
// Enforce only one instance of Llm.

0 commit comments

Comments
 (0)