@@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
269
269
return env->NewStringUTF (result.str ().c_str ());
270
270
}
271
271
272
- extern " C"
273
- JNIEXPORT void JNICALL
274
- Java_android_llama_cpp_LLamaAndroid_free_1batch (JNIEnv *, jobject, jlong batch_pointer) {
275
- llama_batch_free (*reinterpret_cast <llama_batch *>(batch_pointer));
276
- }
277
-
278
272
extern " C"
279
273
JNIEXPORT jlong JNICALL
280
274
Java_android_llama_cpp_LLamaAndroid_new_1batch (JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
@@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
311
305
return reinterpret_cast <jlong>(batch);
312
306
}
313
307
308
+ extern " C"
309
+ JNIEXPORT void JNICALL
310
+ Java_android_llama_cpp_LLamaAndroid_free_1batch (JNIEnv *, jobject, jlong batch_pointer) {
311
+ llama_batch_free (*reinterpret_cast <llama_batch *>(batch_pointer));
312
+ }
313
+
314
+ extern " C"
315
+ JNIEXPORT jlong JNICALL
316
+ Java_android_llama_cpp_LLamaAndroid_new_1sampler (JNIEnv *, jobject) {
317
+ auto sparams = llama_sampler_chain_default_params ();
318
+ sparams.no_perf = true ;
319
+ llama_sampler * smpl = llama_sampler_chain_init (sparams);
320
+ llama_sampler_chain_add (smpl, llama_sampler_init_greedy ());
321
+
322
+ return reinterpret_cast <jlong>(smpl);
323
+ }
324
+
325
+ extern " C"
326
+ JNIEXPORT void JNICALL
327
+ Java_android_llama_cpp_LLamaAndroid_free_1sampler (JNIEnv *, jobject, jlong sampler_pointer) {
328
+ llama_sampler_free (reinterpret_cast <llama_sampler *>(sampler_pointer));
329
+ }
330
+
314
331
extern " C"
315
332
JNIEXPORT void JNICALL
316
333
Java_android_llama_cpp_LLamaAndroid_backend_1init (JNIEnv *, jobject) {
@@ -380,24 +397,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
380
397
JNIEnv * env,
381
398
jobject,
382
399
jlong context_pointer,
383
- jlong sampling_pointer,
384
400
jlong batch_pointer,
401
+ jlong sampler_pointer,
385
402
jint n_len,
386
403
jobject intvar_ncur
387
404
) {
388
405
const auto context = reinterpret_cast <llama_context *>(context_pointer);
389
- const auto sampling = reinterpret_cast <llama_sampler *>(sampling_pointer );
390
- const auto batch = reinterpret_cast <llama_batch *>(batch_pointer );
406
+ const auto batch = reinterpret_cast <llama_batch *>(batch_pointer );
407
+ const auto sampler = reinterpret_cast <llama_sampler *>(sampler_pointer );
391
408
const auto model = llama_get_model (context);
392
409
393
410
if (!la_int_var) la_int_var = env->GetObjectClass (intvar_ncur);
394
411
if (!la_int_var_value) la_int_var_value = env->GetMethodID (la_int_var, " getValue" , " ()I" );
395
412
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID (la_int_var, " inc" , " ()V" );
396
413
397
414
// sample the most likely token
398
- const auto new_token_id = llama_sampler_sample (sampling , context, batch-> n_tokens - 1 );
415
+ const auto new_token_id = llama_sampler_sample (sampler , context, - 1 );
399
416
400
- llama_sampler_accept (sampling , new_token_id);
417
+ llama_sampler_accept (sampler , new_token_id);
401
418
402
419
const auto n_cur = env->CallIntMethod (intvar_ncur, la_int_var_value);
403
420
if (llama_token_is_eog (model, new_token_id) || n_cur == n_len) {
0 commit comments