
Commit f975556

simplify user facing API to return startPos only
1 parent 9e24e2f commit f975556

File tree

1 file changed (+28 -16 lines)


extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 28 additions & 16 deletions
@@ -94,19 +94,6 @@ public native int generate(
       int seqLen,
       LlamaCallback llamaCallback);
 
-  /**
-   * Prefill an LLaVA Module with the given text input.
-   *
-   * @param prompt The text prompt to LLaVA.
-   * @param startPos The starting position in KV cache of the input in the LLM. It's passed as
-   *     reference and will be updated inside this function.
-   * @param bos The number of BOS (begin of sequence) token.
-   * @param eos The number of EOS (end of sequence) token.
-   * @return a tuple of (error, token, updated startPos)
-   */
-  public static native long[] prefillPrompt(
-      String prompt, long startPos, int bos, int eos);
-
   /**
    * Prefill an LLaVA Module with the given images input.
    *
@@ -115,11 +102,36 @@ public static native long[] prefillPrompt(
    * @param height Input image height
    * @param channels Input image number of channels
    * @param startPos The starting position in KV cache of the input in the LLM.
-   * @return a tuple of (error code, updated startPos)
+   * @return The updated starting position in KV cache of the input in the LLM.
    */
-  public static native long[] prefillImages(
+  public long prefillImages(
+      int[] image, int width, int height, int channels, long startPos) {
+    return prefillImagesNative(image, width, height, channels, startPos)[1];
+  }
+
+  // returns a tuple of (error code, updated startPos)
+  private native long[] prefillImagesNative(
       int[] image, int width, int height, int channels, long startPos);
 
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param startPos The starting position in KV cache of the input in the LLM. It's passed as
+   *     reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) token.
+   * @param eos The number of EOS (end of sequence) token.
+   * @return The updated starting position in KV cache of the input in the LLM.
+   */
+  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
+    return prefillPromptNative(prompt, startPos, bos, eos)[2];
+  }
+
+
+  // returns a tuple of (error, token, updated startPos)
+  private native long[] prefillPromptNative(
+      String prompt, long startPos, int bos, int eos);
+
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
@@ -129,7 +141,7 @@ public static native long[] prefillImages(
    * @param llamaCallback callback object to receive results.
    * @return The error code.
    */
-  public static native int generateFromPos(
+  public native int generateFromPos(
       String prompt, int seqLen, long startPos, LlamaCallback callback);
 
   /** Stop current generate() before it finishes. */
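
With this change a caller only threads a single startPos value between prefill calls instead of unpacking the native result arrays ((error, token, startPos) for text, (error, startPos) for images). A minimal usage sketch of the simplified API follows; the class and helper names are hypothetical, the module is assumed to be already loaded, the callback is caller-supplied, and the prompts, image buffer, BOS/EOS counts, and seqLen are placeholders rather than values taken from this commit.

import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class LlavaPrefillSketch {
  /**
   * Sketch only: `module` is assumed to be an already-loaded LLaVA LlamaModule and
   * `callback` a caller-supplied LlamaCallback; the prompts, image buffer, BOS/EOS
   * counts, and seqLen below are illustrative placeholders.
   */
  static int runMultimodalTurn(
      LlamaModule module,
      LlamaCallback callback,
      int[] pixels,
      int width,
      int height,
      int channels) {
    long startPos = 0;

    // Prefill a leading text segment (e.g. a system prompt). The simplified API
    // returns only the updated KV-cache position instead of a result tuple.
    startPos = module.prefillPrompt("A chat between a user and an assistant. ", startPos, 1, 0);

    // Prefill the image, continuing from the position the text prefill returned.
    startPos = module.prefillImages(pixels, width, height, channels, startPos);

    // Generate from the accumulated position; tokens stream to the callback and
    // the return value is the error code.
    return module.generateFromPos("USER: What is in this image? ASSISTANT:", 128, startPos, callback);
  }
}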
