Skip to content

Commit 9e24e2f

Browse files
committed
fix
1 parent 6715585 commit 9e24e2f

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ class ExecuTorchLlamaJni
188188
facebook::jni::alias_ref<jstring> prompt,
189189
jlong start_pos,
190190
jint bos,
191-
jint eos,
192-
jlong generated_token) {
191+
jint eos) {
193192
facebook::jni::local_ref<jlongArray> tuple_result =
194193
facebook::jni::make_long_array(3);
195194
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
@@ -201,9 +200,7 @@ class ExecuTorchLlamaJni
201200
prompt->toStdString(), start_pos, bos, eos);
202201
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
203202
if (result.ok()) {
204-
// TODO(hsz): make generated_token a reference and update it here
205-
generated_token = result.get();
206-
tuple_result->pin()[1] = static_cast<jlong>(generated_token);
203+
tuple_result->pin()[1] = static_cast<jlong>(result.get());
207204
tuple_result->pin()[2] = static_cast<jlong>(start_pos);
208205
}
209206
return tuple_result;

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ public native int generate(
104104
* @param eos The number of EOS (end of sequence) token.
105105
* @return a tuple of (error, token, updated startPos)
106106
*/
107-
public static native long[] prefill_prompt(
108-
String prompt, long startPos, int bos, int eos, long generatedToken);
107+
public static native long[] prefillPrompt(
108+
String prompt, long startPos, int bos, int eos);
109109

110110
/**
111111
* Prefill an LLaVA Module with the given images input.
@@ -117,7 +117,7 @@ public static native long[] prefill_prompt(
117117
* @param startPos The starting position in KV cache of the input in the LLM.
118118
* @return a tuple of (error code, updated startPos)
119119
*/
120-
public static native long[] prefill_images(
120+
public static native long[] prefillImages(
121121
int[] image, int width, int height, int channels, long startPos);
122122

123123
/**
@@ -129,8 +129,8 @@ public static native long[] prefill_images(
129129
* @param llamaCallback callback object to receive results.
130130
* @return The error code.
131131
*/
132-
public static native int generate_from_pos(
133-
String prompt, int seqLen, long startPos, ExecuTorchLlamaCallback callback);
132+
public static native int generateFromPos(
133+
String prompt, int seqLen, long startPos, LlamaCallback callback);
134134

135135
/** Stop current generate() before it finishes. */
136136
@DoNotStrip

0 commit comments

Comments
 (0)