Add Echo parameter to multimodal runner (llava) and jni layer #5181

Merged (2 commits, Sep 9, 2024)

@@ -675,8 +675,8 @@ public void run() {
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
rawPrompt,
ModelUtils.VISION_MODEL_SEQ_LEN,
false,
MainActivity.this);
MainActivity.this,
false);
} else {
// no image selected, we pass in empty int array
mModule.generate(
@@ -686,8 +686,8 @@ public void run() {
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
rawPrompt,
ModelUtils.VISION_MODEL_SEQ_LEN,
false,
MainActivity.this);
MainActivity.this,
false);
}
} else {
String finalPrompt =
@@ -696,8 +696,8 @@ public void run() {
mModule.generate(
finalPrompt,
(int) (finalPrompt.length() * 0.75) + 64,
false,
MainActivity.this);
MainActivity.this,
false);
}

long generateDuration = System.currentTimeMillis() - generateStartTime;
14 changes: 9 additions & 5 deletions examples/models/llava/runner/llava_runner.cpp
@@ -99,9 +99,12 @@ Error LlavaRunner::generate_from_pos(
int64_t start_pos,
std::function<void(const std::string&)> token_callback,
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback) {
stats_callback,
bool echo) {
// prefill user prompt. No BOS because preset prompt already has it.
token_callback(prompt);
if (echo) {
token_callback(prompt);
}

uint64_t prefill_next_token =
ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
@@ -125,7 +128,8 @@ Error LlavaRunner::generate(
const std::string& prompt,
int32_t seq_len,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
std::function<void(const Stats&)> stats_callback,
bool echo) {
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
if (!is_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
@@ -160,8 +164,8 @@ Error LlavaRunner::generate(
util::get_rss_bytes() / 1024.0 / 1024.0);

// Generate tokens
Error err =
generate_from_pos(prompt, seq_len, pos, wrapped_callback, stats_callback);
Error err = generate_from_pos(
prompt, seq_len, pos, wrapped_callback, stats_callback, echo);

ET_LOG(
Info,
7 changes: 5 additions & 2 deletions examples/models/llava/runner/llava_runner.h
@@ -36,7 +36,8 @@ class LlavaRunner : public MultimodalRunner {
int32_t seq_len = 1024,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback = {});
stats_callback = {},
bool echo = true);

/**
* Prefill an LLaVA Module with the given images input.
@@ -70,6 +71,7 @@
* @param start_pos The starting position in KV cache of the input in the LLM.
* @param token_callback What to do after a token is generated.
* @param stats_callback What to do with Stats.
* @param echo Whether to echo the input prompt or not.
* @return The error code.
*/
Error generate_from_pos(
@@ -78,7 +80,8 @@
int64_t start_pos = 0,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback = {});
stats_callback = {},
bool echo = true);

private:
inline static const std::string kPresetPrompt =
13 changes: 8 additions & 5 deletions extension/android/jni/jni_layer_llama.cpp
@@ -150,8 +150,8 @@ class ExecuTorchLlamaJni
jint channels,
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jboolean echo,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
auto image_size = image->size();
std::vector<Image> images;
@@ -170,7 +170,8 @@ class ExecuTorchLlamaJni
prompt->toStdString(),
seq_len,
[callback](std::string result) { callback->onResult(result); },
[callback](const Stats& result) { callback->onStats(result); });
[callback](const Stats& result) { callback->onStats(result); },
echo);
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
runner_->generate(
prompt->toStdString(),
@@ -248,7 +249,8 @@ class ExecuTorchLlamaJni
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
jboolean echo) {
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(Error::NotSupported);
}
@@ -259,7 +261,8 @@ class ExecuTorchLlamaJni
[callback](const std::string& result) { callback->onResult(result); },
[callback](const ::executorch::extension::llm::Stats& stats) {
callback->onStats(stats);
}));
},
echo));
}

void stop() {
@@ -60,7 +60,7 @@ public void resetNative() {
* @param llamaCallback callback object to receive results.
*/
public int generate(String prompt, LlamaCallback llamaCallback) {
return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
}

/**
@@ -71,30 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
* @param llamaCallback callback object to receive results.
*/
public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO);
}

/**
* Start generating tokens from the module.
*
* @param prompt Input prompt
* @param llamaCallback callback object to receive results
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
* @param llamaCallback callback object to receive results.
*/
public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
public int generate(String prompt, LlamaCallback llamaCallback, boolean echo) {
return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, echo);
}

/**
* Start generating tokens from the module.
*
* @param prompt Input prompt
* @param seqLen sequence length
* @param llamaCallback callback object to receive results
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
* @param llamaCallback callback object to receive results.
*/
public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, boolean echo) {
return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, echo);
}

/**
@@ -106,8 +106,8 @@ public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llama
* @param channels Input image number of channels
* @param prompt Input prompt
* @param seqLen sequence length
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
* @param llamaCallback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
*/
@DoNotStrip
public native int generate(
@@ -117,8 +117,8 @@ public native int generate(
int channels,
String prompt,
int seqLen,
boolean echo,
LlamaCallback llamaCallback);
LlamaCallback llamaCallback,
boolean echo);

/**
* Prefill an LLaVA Module with the given images input.
@@ -172,10 +172,11 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param llamaCallback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not.
* @return The error code.
*/
public native int generateFromPos(
String prompt, int seqLen, long startPos, LlamaCallback callback);
String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);

Contributor:
Shouldn't echo be the last argument?

Contributor Author:
On the Android side, echo is not the last argument, but on the JNI side echo is passed into the runner as the last argument.

Contributor Author:
@larryliu0820 The latest commits add this as the last parameter in all layers.
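
For reference, a minimal caller-side sketch of the final call order, with echo as the last argument in every overload. This is illustrative only: the package imports and the pre-existing module/callback objects are assumptions based on the demo app, not part of this diff.

import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class EchoUsageSketch {
  // `module` and `callback` are assumed to already exist, as in the demo app.
  static void generateWithoutEcho(LlamaModule module, LlamaCallback callback, String prompt) {
    // echo=false: do not replay the input prompt through the callback (chat style);
    // echo=true would emit the prompt before the generated tokens (text-completion style).
    module.generate(prompt, /* seqLen */ 256, callback, /* echo */ false);

    // Continuation from a known KV-cache position; echo is again the last argument.
    module.generateFromPos(prompt, /* seqLen */ 256, /* startPos */ 0L, callback, /* echo */ false);
  }
}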

/** Stop current generate() before it finishes. */
@DoNotStrip
7 changes: 5 additions & 2 deletions extension/llm/runner/multimodal_runner.h
@@ -59,7 +59,8 @@ class MultimodalRunner {
const std::string& prompt,
int32_t seq_len = 1024,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) = 0;
std::function<void(const Stats&)> stats_callback = {},
bool echo = true) = 0;

/**
* Prefill an LLaVA Module with the given images input.
@@ -95,6 +96,7 @@ class MultimodalRunner {
* @param start_pos The starting position in KV cache of the input in the LLM.
* @param token_callback What to do after a token is generated.
* @param stats_callback What to do with Stats.
* @param echo Whether to echo the input prompt or not.
* @return The error code.
*/
virtual runtime::Error generate_from_pos(
@@ -103,7 +105,8 @@
int64_t start_pos = 0,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback = {}) = 0;
stats_callback = {},
bool echo = true) = 0;

inline void stop() {
text_token_generator_->stop();