Skip to content

Commit 8d375bf

Browse files
committed
Address comments about what to expose
1 parent 7b5a0bc commit 8d375bf

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,17 +180,16 @@ class ExecuTorchLlamaJni
180180
return 0;
181181
}
182182

183-
// Returns a tuple of (error, token, start_pos)
183+
// Returns a tuple of (error, start_pos)
184184
// Contract is valid within an AAR (JNI + corresponding Java code)
185-
// If the first element is not Error::Ok, the other two elements are
186-
// undefined.
185+
// If the first element is not Error::Ok, the other element is undefined.
187186
facebook::jni::local_ref<jlongArray> prefill_prompt(
188187
facebook::jni::alias_ref<jstring> prompt,
189188
jlong start_pos,
190189
jint bos,
191190
jint eos) {
192191
facebook::jni::local_ref<jlongArray> tuple_result =
193-
facebook::jni::make_long_array(3);
192+
facebook::jni::make_long_array(2);
194193
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
195194
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
196195
return tuple_result;
@@ -200,8 +199,7 @@ class ExecuTorchLlamaJni
200199
prompt->toStdString(), start_pos, bos, eos);
201200
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
202201
if (result.ok()) {
203-
tuple_result->pin()[1] = static_cast<jlong>(result.get());
204-
tuple_result->pin()[2] = static_cast<jlong>(start_pos);
202+
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
205203
}
206204
return tuple_result;
207205
}

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public long prefillImages(int[] image, int width, int height, int channels, long
113113
return nativeResult[1];
114114
}
115115

116-
// returns a tuple of (error code, updated startPos)
116+
// returns a tuple of (status, updated startPos)
117117
private native long[] prefillImagesNative(
118118
int[] image, int width, int height, int channels, long startPos);
119119

@@ -133,10 +133,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
133133
if (nativeResult[0] != 0) {
134134
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
135135
}
136-
return nativeResult[2];
136+
return nativeResult[1];
137137
}
138138

139-
// returns a tuple of (error, token, updated startPos)
139+
// returns a tuple of (status, updated startPos)
140140
private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
141141

142142
/**

0 commit comments

Comments
 (0)