Commit f55ce1f

Llava prefill Java API
Add Java API and JNI layer

Pull Request resolved: #5132
1 parent 2763233 commit f55ce1f

3 files changed, +181 −0 lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 80 additions & 0 deletions
@@ -180,6 +180,86 @@ class ExecuTorchLlamaJni
    return 0;
  }

  // Returns a tuple of (error, start_pos)
  // Contract is valid within an AAR (JNI + corresponding Java code)
  // If the first element is not Error::Ok, the other element is undefined.
  facebook::jni::local_ref<jlongArray> prefill_prompt(
      facebook::jni::alias_ref<jstring> prompt,
      jlong start_pos,
      jint bos,
      jint eos) {
    facebook::jni::local_ref<jlongArray> tuple_result =
        facebook::jni::make_long_array(2);
    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
      return tuple_result;
    }

    auto&& result = multi_modal_runner_->prefill_prompt(
        prompt->toStdString(), start_pos, bos, eos);
    // Propagate the runner's actual status; start_pos was updated in place.
    tuple_result->pin()[0] = static_cast<jint>(result.error());
    if (result.ok()) {
      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    }
    return tuple_result;
  }

  // Returns a tuple of (error, start_pos)
  // Contract is valid within an AAR (JNI + corresponding Java code)
  // If the first element is not Error::Ok, the other element is undefined.
  facebook::jni::local_ref<jlongArray> prefill_images(
      facebook::jni::alias_ref<jintArray> image,
      jint width,
      jint height,
      jint channels,
      jlong start_pos) {
    facebook::jni::local_ref<jlongArray> tuple_result =
        facebook::jni::make_long_array(2);

    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
      return tuple_result;
    }

    auto image_size = image->size();
    std::vector<Image> images;
    if (image_size != 0) {
      std::vector<jint> image_data_jint(image_size);
      std::vector<uint8_t> image_data(image_size);
      image->getRegion(0, image_size, image_data_jint.data());
      // Narrow each jint to a uint8_t channel value.
      for (int i = 0; i < image_size; i++) {
        image_data[i] = image_data_jint[i];
      }
      Image image_runner{image_data, width, height, channels};
      images.push_back(image_runner);
    }
    // TODO(hsz): make start_pos a reference and update it here
    jint result = static_cast<jint>(
        multi_modal_runner_->prefill_images(images, start_pos));
    tuple_result->pin()[0] = result;
    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    return tuple_result;
  }

  jint generate_from_pos(
      facebook::jni::alias_ref<jstring> prompt,
      jint seq_len,
      jlong start_pos,
      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      return static_cast<jint>(Error::NotSupported);
    }
    return static_cast<jint>(multi_modal_runner_->generate_from_pos(
        prompt->toStdString(),
        seq_len,
        start_pos,
        [callback](const std::string& result) { callback->onResult(result); },
        [callback](const ::executorch::extension::llm::Stats& stats) {
          callback->onStats(stats);
        }));
  }

  void stop() {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      multi_modal_runner_->stop();

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 57 additions & 0 deletions
@@ -94,6 +94,63 @@ public native int generate(
      int seqLen,
      LlamaCallback llamaCallback);

  /**
   * Prefill an LLaVA Module with the given images input.
   *
   * @param image Input image as an int array, one 0-255 channel value per element
   * @param width Input image width
   * @param height Input image height
   * @param channels Number of channels in the input image
   * @param startPos The starting position in KV cache of the input in the LLM.
   * @return The updated starting position in KV cache of the input in the LLM.
   * @throws RuntimeException if the prefill fails
   */
  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
    if (nativeResult[0] != 0) {
      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
    }
    return nativeResult[1];
  }

  // Returns a tuple of (status, updated startPos)
  private native long[] prefillImagesNative(
      int[] image, int width, int height, int channels, long startPos);
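The JNI layer narrows each element of the image array to a uint8_t, so the caller supplies one 0-255 channel value per int. Below is a minimal sketch of producing that array from an android.graphics.Bitmap; the helper name and the interleaved RGB (HWC) ordering are illustrative assumptions, not part of this API, so check the runner's Image contract for the layout your model expects.

  // Hypothetical helper (not part of this commit): flatten a Bitmap into the
  // int[] consumed by prefillImages(). Assumes 3-channel interleaved RGB.
  static int[] bitmapToChannels(android.graphics.Bitmap bitmap) {
    int width = bitmap.getWidth();
    int height = bitmap.getHeight();
    int[] pixels = new int[width * height];
    bitmap.getPixels(pixels, 0, width, 0, 0, width, height);
    int[] out = new int[width * height * 3];
    for (int i = 0; i < pixels.length; i++) {
      int p = pixels[i];
      out[i * 3] = (p >> 16) & 0xFF; // R
      out[i * 3 + 1] = (p >> 8) & 0xFF; // G
      out[i * 3 + 2] = p & 0xFF; // B
    }
    return out;
  }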
  /**
   * Prefill an LLaVA Module with the given text input.
   *
   * @param prompt The text prompt to LLaVA.
   * @param startPos The starting position in KV cache of the input in the LLM.
   * @param bos The number of BOS (begin of sequence) tokens to prepend.
   * @param eos The number of EOS (end of sequence) tokens to append.
   * @return The updated starting position in KV cache of the input in the LLM.
   * @throws RuntimeException if the prefill fails
   */
  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
    if (nativeResult[0] != 0) {
      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
    }
    return nativeResult[1];
  }

  // Returns a tuple of (status, updated startPos)
  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);

  /**
   * Generate tokens from the given prompt, starting from the given position.
   *
   * @param prompt The text prompt to LLaVA.
   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
   * @param startPos The starting position in KV cache of the input in the LLM.
   * @param callback Callback object to receive results.
   * @return The error code.
   */
  public native int generateFromPos(
      String prompt, int seqLen, long startPos, LlamaCallback callback);

  /** Stop current generate() before it finishes. */
  @DoNotStrip
  public native void stop();
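Taken together, the new methods let an app thread the KV-cache position through image and text prefill before decoding. A hedged sketch of that flow follows; `module`, `callback`, `imagePixels`, the prompts, and the sequence length of 768 are illustrative assumptions rather than values prescribed by the API.

  // Illustrative usage: assumes `module` is a LlamaModule loaded with a LLaVA
  // model and `callback` implements LlamaCallback.
  long startPos = 0;
  // Prefill the system prompt, prepending one BOS token and no EOS token.
  startPos = module.prefillPrompt("A chat between a human and an assistant. USER: ", startPos, 1, 0);
  // Prefill the image at the updated KV-cache position (3 = RGB channels).
  startPos = module.prefillImages(imagePixels, width, height, 3, startPos);
  // Decode the answer to the user's question from the final position.
  int err = module.generateFromPos("What is in this picture? ASSISTANT: ", 768, startPos, callback);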

extension/llm/runner/multimodal_runner.h

Lines changed: 44 additions & 0 deletions
@@ -61,6 +61,50 @@ class MultimodalRunner {
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {}) = 0;

  /**
   * Prefill an LLaVA Module with the given images input.
   * @param images The image input to LLaVA.
   * @param start_pos The starting position in KV cache of the input in the
   * LLM. It's passed as reference and will be updated inside this function.
   * @return The error status of prefilling images.
   */
  virtual runtime::Error prefill_images(
      std::vector<Image>& images,
      int64_t& start_pos) = 0;

  /**
   * Prefill an LLaVA Module with the given text input.
   * @param prompt The text prompt to LLaVA.
   * @param start_pos The starting position in KV cache of the input in the
   * LLM. It's passed as reference and will be updated inside this function.
   * @param bos The number of BOS (begin of sequence) tokens to prepend.
   * @param eos The number of EOS (end of sequence) tokens to append.
   * @return The token generated by the LLaVA Module after prefilling the prompt.
   */
  virtual runtime::Result<uint64_t> prefill_prompt(
      const std::string& prompt,
      int64_t& start_pos,
      int8_t bos = 0,
      int8_t eos = 0) = 0;

  /**
   * Generate tokens from the given prompt, starting from the given position.
   * @param prompt The text prompt to LLaVA.
   * @param seq_len The total sequence length, including the prompt tokens and
   * new tokens.
   * @param start_pos The starting position in KV cache of the input in the LLM.
   * @param token_callback What to do after a token is generated.
   * @param stats_callback What to do with Stats.
   * @return The error code.
   */
  virtual runtime::Error generate_from_pos(
      const std::string& prompt,
      int32_t seq_len = 1024,
      int64_t start_pos = 0,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {}) = 0;

  inline void stop() {
    text_token_generator_->stop();
  }
