Prefill API for JNI #5132

Merged · 8 commits · Sep 6, 2024
80 changes: 80 additions & 0 deletions extension/android/jni/jni_layer_llama.cpp
@@ -180,6 +180,86 @@ class ExecuTorchLlamaJni
return 0;
}

// Returns a tuple of (error, start_pos)
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.
facebook::jni::local_ref<jlongArray> prefill_prompt(
facebook::jni::alias_ref<jstring> prompt,
jlong start_pos,
jint bos,
jint eos) {
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jlong>(Error::NotSupported);
return tuple_result;
}

auto&& result = multi_modal_runner_->prefill_prompt(
prompt->toStdString(), start_pos, bos, eos);
// Propagate the runner's status; the start_pos element is only valid when
// the status is Error::Ok.
tuple_result->pin()[0] = static_cast<jlong>(result.error());
if (result.ok()) {
// The runner takes start_pos by reference and has updated it in place.
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
}
return tuple_result;
}

// Returns a tuple of (error, start_pos)
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.
facebook::jni::local_ref<jlongArray> prefill_images(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
jlong start_pos) {
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);

if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jlong>(Error::NotSupported);
return tuple_result;
}

auto image_size = image->size();
std::vector<Image> images;
if (image_size != 0) {
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
image->getRegion(0, image_size, image_data_jint.data());
for (size_t i = 0; i < image_size; i++) {
// Each jint carries one channel value in [0, 255]; narrow it to uint8_t.
image_data[i] = static_cast<uint8_t>(image_data_jint[i]);
}
Image image_runner{image_data, width, height, channels};
images.push_back(image_runner);
}
// TODO(hsz): make start_pos a reference and update it here
jint result = static_cast<jint>(
multi_modal_runner_->prefill_images(images, start_pos));
tuple_result->pin()[0] = result;
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
return tuple_result;
}

jint generate_from_pos(
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(Error::NotSupported);
}
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
prompt->toStdString(),
seq_len,
start_pos,
[callback](const std::string& result) { callback->onResult(result); },
[callback](const ::executorch::extension::llm::Stats& stats) {
callback->onStats(stats);
}));
}

void stop() {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_->stop();
@@ -94,6 +94,63 @@ public native int generate(
int seqLen,
LlamaCallback llamaCallback);

/**
* Prefill an LLaVA Module with the given images input.
*
* @param image Input image as an int array; each element holds a single channel value in [0, 255]
* @param width Input image width
* @param height Input image height
* @param channels Input image number of channels
* @param startPos The starting position in KV cache of the input in the LLM.
* @return The updated starting position in KV cache of the input in the LLM.
* @throws RuntimeException if the prefill fails
*/
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
}

// returns a tuple of (status, updated startPos)
private native long[] prefillImagesNative(
int[] image, int width, int height, int channels, long startPos);
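
A minimal sketch of how a caller might produce the int[] that prefillImages() consumes. Each array element carries a single channel value in [0, 255], matching the per-element narrowing to uint8_t in the JNI layer above. The interleaved RGB (HWC) ordering and the helper name flattenBitmap are illustrative assumptions; the layout a given model expects depends on its preprocessing.

// Sketch only: flatten an android.graphics.Bitmap into one channel value
// per int element. Interleaved RGB order is an assumption; verify it
// against the model's expected preprocessing.
public static int[] flattenBitmap(android.graphics.Bitmap bitmap) {
  int width = bitmap.getWidth();
  int height = bitmap.getHeight();
  int[] pixels = new int[width * height];
  bitmap.getPixels(pixels, 0, width, 0, 0, width, height);
  int[] flat = new int[width * height * 3]; // 3 channels; alpha is dropped
  for (int i = 0; i < pixels.length; i++) {
    int argb = pixels[i];
    flat[i * 3] = (argb >> 16) & 0xFF; // R
    flat[i * 3 + 1] = (argb >> 8) & 0xFF; // G
    flat[i * 3 + 2] = argb & 0xFF; // B
  }
  return flat;
}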

/**
* Prefill an LLaVA Module with the given text input.
*
* @param prompt The text prompt to LLaVA.
* @param startPos The starting position in KV cache of the input in the LLM. Java primitives are
*     passed by value, so the updated position is returned rather than written back.
* @param bos The number of BOS (begin of sequence) tokens to prepend.
* @param eos The number of EOS (end of sequence) tokens to append.
* @return The updated starting position in KV cache of the input in the LLM.
* @throws RuntimeException if the prefill fails
*/
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
}

// returns a tuple of (status, updated startPos)
private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
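
To illustrate the position-chaining contract, a short fragment (prompt text, image dimensions, and bos/eos counts are illustrative): the position returned by one prefill call seeds the next, so image and text prefills accumulate into a single KV-cache prefix.

// Sketch only: module is an initialized instance of this class, and
// imagePixels comes from a helper such as flattenBitmap() above.
long pos = 0;
pos = module.prefillImages(imagePixels, /* width */ 336, /* height */ 336, /* channels */ 3, pos);
pos = module.prefillPrompt("USER: What is in this image?", pos, /* bos */ 1, /* eos */ 0);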

/**
* Generate tokens from the given prompt, starting from the given position.
*
* @param prompt The text prompt to LLaVA.
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback Callback object to receive results.
* @return The error code.
*/
public native int generateFromPos(
String prompt, int seqLen, long startPos, LlamaCallback callback);

/** Stop current generate() before it finishes. */
@DoNotStrip
public native void stop();
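
Putting the pieces together, a hedged end-to-end sketch. It assumes the enclosing module class (called LlamaModule here) is already initialized, reuses the hypothetical flattenBitmap() helper above, and assumes LlamaCallback exposes onResult(String) and onStats(float); check the interface as declared in this package before copying.

// Sketch only: prefill an image, then generate from the resulting position.
static void describeImage(LlamaModule module, android.graphics.Bitmap bitmap) {
  long pos = module.prefillImages(
      flattenBitmap(bitmap), bitmap.getWidth(), bitmap.getHeight(), 3, 0);
  int err = module.generateFromPos(
      "USER: Describe the image. ASSISTANT:",
      /* seqLen */ 768,
      pos,
      new LlamaCallback() {
        @Override
        public void onResult(String token) {
          System.out.print(token);
        }

        @Override
        public void onStats(float tokensPerSecond) { // signature assumed
          System.out.println("\ntok/s: " + tokensPerSecond);
        }
      });
  if (err != 0) {
    throw new RuntimeException("generateFromPos failed with error code: " + err);
  }
}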
44 changes: 44 additions & 0 deletions extension/llm/runner/multimodal_runner.h
@@ -61,6 +61,50 @@ class MultimodalRunner {
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) = 0;

/**
* Prefill an LLaVA Module with the given images input.
* @param images The image input to LLaVA.
* @param start_pos The starting position in KV cache of the input in the LLM.
* It's passed as reference and will be updated inside this function.
* @return The error status of prefilling images.
*/
virtual runtime::Error prefill_images(
std::vector<Image>& images,
int64_t& start_pos) = 0;

/**
* Prefill an LLaVA Module with the given text input.
* @param prompt The text prompt to LLaVA.
* @param start_pos The starting position in KV cache of the input in the LLM.
* It's passed as reference and will be updated inside this function.
* @param bos The number of BOS (begin of sequence) tokens to prepend.
* @param eos The number of EOS (end of sequence) tokens to append.
* @return The token generated by the LLaVA module after the prompt is prefilled.
*/
virtual runtime::Result<uint64_t> prefill_prompt(
const std::string& prompt,
int64_t& start_pos,
int8_t bos = 0,
int8_t eos = 0) = 0;

/**
* Generate tokens from the given prompt, starting from the given position.
* @param prompt The text prompt to LLaVA.
* @param seq_len The total sequence length, including the prompt tokens and
* new tokens.
* @param start_pos The starting position in KV cache of the input in the LLM.
* @param token_callback What to do after a token is generated.
* @param stats_callback What to do with Stats.
* @return The error code.
*/
virtual runtime::Error generate_from_pos(
const std::string& prompt,
int32_t seq_len = 1024,
int64_t start_pos = 0,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback = {}) = 0;

inline void stop() {
text_token_generator_->stop();
}