Commit 542ecb5

Add Echo parameter to multimodal runner (llava) and jni layer (#5181)
* Add Echo parameter to multimodal runner (llava) and jni layer
* Rebasing - Unify order of echo parameter to be last in all layers
1 parent 6ce9f52 commit 542ecb5
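
The practical upshot of the reorder is a single convention across the Java API, the JNI layer, and the C++ runners: the callback comes before the trailing echo flag. A minimal Java sketch of the resulting call sites (the module construction, temperature value, and variable names are illustrative assumptions, not part of this diff):

// Assumed setup; only the generate() calls reflect this commit's API.
LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, /*temperature*/ 0.8f);

// Text completion: keep the prompt in the output stream.
mModule.generate("Once upon a time", /*seqLen*/ 128, callback, /*echo*/ true);

// Chat-style usage: suppress the prompt echo.
mModule.generate(chatPrompt, /*seqLen*/ 256, callback, /*echo*/ false);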

File tree

6 files changed: +46 -32 lines changed


examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 6 additions & 6 deletions
@@ -675,8 +675,8 @@ public void run() {
               ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
               rawPrompt,
               ModelUtils.VISION_MODEL_SEQ_LEN,
-              false,
-              MainActivity.this);
+              MainActivity.this,
+              false);
         } else {
           // no image selected, we pass in empty int array
           mModule.generate(
@@ -686,8 +686,8 @@ public void run() {
               ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
               rawPrompt,
               ModelUtils.VISION_MODEL_SEQ_LEN,
-              false,
-              MainActivity.this);
+              MainActivity.this,
+              false);
         }
       } else {
         String finalPrompt =
@@ -696,8 +696,8 @@ public void run() {
         mModule.generate(
             finalPrompt,
             (int) (finalPrompt.length() * 0.75) + 64,
-            false,
-            MainActivity.this);
+            MainActivity.this,
+            false);
       }
 
       long generateDuration = System.currentTimeMillis() - generateStartTime;
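
Two details of these call sites are easy to miss: MainActivity.this is accepted as the callback argument because the demo activity implements LlamaCallback, and (int) (finalPrompt.length() * 0.75) + 64 is a rough sequence-length budget (an estimate of the prompt's token count plus headroom for new tokens). The hard-coded false keeps the demo in chat mode, so the prompt is not replayed into the transcript. A minimal sketch of the callback side, with hypothetical helper names and assumed callback signatures:

import android.util.Log;
import androidx.appcompat.app.AppCompatActivity;
import org.pytorch.executorch.LlamaCallback;

public class MainActivity extends AppCompatActivity implements LlamaCallback {
  @Override
  public void onResult(String result) {
    // Stream each decoded piece into the transcript; appendResult() is hypothetical.
    runOnUiThread(() -> appendResult(result));
  }

  @Override
  public void onStats(float tokensPerSecond) {
    Log.i("LlamaDemo", "tokens/s: " + tokensPerSecond);
  }

  private void appendResult(String result) {
    // UI update elided in this sketch.
  }
}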

examples/models/llava/runner/llava_runner.cpp

Lines changed: 9 additions & 5 deletions
@@ -99,9 +99,12 @@ Error LlavaRunner::generate_from_pos(
     int64_t start_pos,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const ::executorch::extension::llm::Stats&)>
-        stats_callback) {
+        stats_callback,
+    bool echo) {
   // prefill user prompt. No BOS because preset prompt already has it.
-  token_callback(prompt);
+  if (echo) {
+    token_callback(prompt);
+  }
 
   uint64_t prefill_next_token =
       ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
@@ -125,7 +128,8 @@ Error LlavaRunner::generate(
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    bool echo) {
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
@@ -160,8 +164,8 @@ Error LlavaRunner::generate(
       util::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Generate tokens
-  Error err =
-      generate_from_pos(prompt, seq_len, pos, wrapped_callback, stats_callback);
+  Error err = generate_from_pos(
+      prompt, seq_len, pos, wrapped_callback, stats_callback, echo);
 
   ET_LOG(
       Info,
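
The only behavioral change in the runner is the guard at the top of generate_from_pos: the prompt is forwarded to token_callback only when echo is true, so the flag decides whether the consumer's result stream begins with the prompt text or with the first generated token. Seen from the Java side (the output shape in the comments is illustrative):

// echo = true  -> onResult("Once upon a time"), onResult(" there"), ...
// echo = false -> onResult(" there"), onResult(" was"), ...
mModule.generate("Once upon a time", /*seqLen*/ 128, callback, /*echo*/ true);
mModule.generate("Once upon a time", /*seqLen*/ 128, callback, /*echo*/ false);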

examples/models/llava/runner/llava_runner.h

Lines changed: 5 additions & 2 deletions
@@ -36,7 +36,8 @@ class LlavaRunner : public MultimodalRunner {
       int32_t seq_len = 1024,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
-          stats_callback = {});
+          stats_callback = {},
+      bool echo = true);
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -70,6 +71,7 @@ class LlavaRunner : public MultimodalRunner {
    * @param start_pos The starting position in KV cache of the input in the LLM.
    * @param token_callback What to do after a token is generated.
    * @param stats_callback What to do with Stats.
+   * @param echo Whether to echo the input prompt or not.
    * @return The error code.
    */
   Error generate_from_pos(
@@ -78,7 +80,8 @@ class LlavaRunner : public MultimodalRunner {
       int64_t start_pos = 0,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
-          stats_callback = {});
+          stats_callback = {},
+      bool echo = true);
 
  private:
   inline static const std::string kPresetPrompt =

extension/android/jni/jni_layer_llama.cpp

Lines changed: 8 additions & 5 deletions
@@ -150,8 +150,8 @@ class ExecuTorchLlamaJni
       jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
-      jboolean echo,
-      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
+      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
+      jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       auto image_size = image->size();
       std::vector<Image> images;
@@ -170,7 +170,8 @@ class ExecuTorchLlamaJni
           prompt->toStdString(),
           seq_len,
           [callback](std::string result) { callback->onResult(result); },
-          [callback](const Stats& result) { callback->onStats(result); });
+          [callback](const Stats& result) { callback->onStats(result); },
+          echo);
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       runner_->generate(
           prompt->toStdString(),
@@ -248,7 +249,8 @@ class ExecuTorchLlamaJni
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       jlong start_pos,
-      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
+      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
+      jboolean echo) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
       return static_cast<jint>(Error::NotSupported);
     }
@@ -259,7 +261,8 @@ class ExecuTorchLlamaJni
         [callback](const std::string& result) { callback->onResult(result); },
         [callback](const ::executorch::extension::llm::Stats& stats) {
           callback->onStats(stats);
-        }));
+        },
+        echo));
   }
 
   void stop() {
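
Because JNI binds native methods by their exact Java signature, moving echo after the callback in these C++ entry points only resolves correctly if the @DoNotStrip native declarations in LlamaModule.java (next file) are reordered in lockstep. For reference, the multimodal entry point as declared on the Java side after this commit:

public native int generate(
    int[] image,
    int width,
    int height,
    int channels,
    String prompt,
    int seqLen,
    LlamaCallback llamaCallback,
    boolean echo);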

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 13 additions & 12 deletions
@@ -60,7 +60,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
+    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
   }
 
   /**
@@ -71,30 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
-    return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
+    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO);
   }
 
   /**
    * Start generating tokens from the module.
    *
    * @param prompt Input prompt
+   * @param llamaCallback callback object to receive results
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
-   * @param llamaCallback callback object to receive results.
    */
-  public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
-    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
+  public int generate(String prompt, LlamaCallback llamaCallback, boolean echo) {
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, echo);
   }
 
   /**
    * Start generating tokens from the module.
    *
    * @param prompt Input prompt
    * @param seqLen sequence length
+   * @param llamaCallback callback object to receive results
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
-   * @param llamaCallback callback object to receive results.
    */
-  public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
-    return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
+  public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, boolean echo) {
+    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, echo);
   }
 
   /**
@@ -106,8 +106,8 @@ public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
    * @param channels Input image number of channels
    * @param prompt Input prompt
    * @param seqLen sequence length
-   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    * @param llamaCallback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
   @DoNotStrip
   public native int generate(
@@ -117,8 +117,8 @@ public native int generate(
       int channels,
       String prompt,
       int seqLen,
-      boolean echo,
-      LlamaCallback llamaCallback);
+      LlamaCallback llamaCallback,
+      boolean echo);
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -172,10 +172,11 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param seqLen The total sequence length, including the prompt tokens and new tokens.
    * @param startPos The starting position in KV cache of the input in the LLM.
    * @param llamaCallback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
   public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlamaCallback callback);
+      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
 
   /** Stop current generate() before it finishes. */
   @DoNotStrip
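
Putting it together, a hypothetical multimodal call site after this change; the image buffer, dimensions, and helper are placeholders, and only the argument order reflects this commit:

// Flattened RGB pixels; loadImageAsIntArray() and the dimensions are placeholders.
int[] imageRgb = loadImageAsIntArray();
int err = mModule.generate(
    imageRgb,
    /*width*/ 336,
    /*height*/ 336,
    /*channels*/ 3,
    "What is in this picture?",
    /*seqLen*/ 768,
    callback,
    /*echo*/ false); // chat-style: don't replay the prompt

// Resuming from a KV-cache position keeps the same trailing order:
err = mModule.generateFromPos(followUp, /*seqLen*/ 768, startPos, callback, /*echo*/ false);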

extension/llm/runner/multimodal_runner.h

Lines changed: 5 additions & 2 deletions
@@ -59,7 +59,8 @@ class MultimodalRunner {
       const std::string& prompt,
       int32_t seq_len = 1024,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {}) = 0;
+      std::function<void(const Stats&)> stats_callback = {},
+      bool echo = true) = 0;
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -95,6 +96,7 @@ class MultimodalRunner {
    * @param start_pos The starting position in KV cache of the input in the LLM.
    * @param token_callback What to do after a token is generated.
    * @param stats_callback What to do with Stats.
+   * @param echo Whether to echo the input prompt or not.
    * @return The error code.
    */
   virtual runtime::Error generate_from_pos(
@@ -103,7 +105,8 @@ class MultimodalRunner {
       int64_t start_pos = 0,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
-          stats_callback = {}) = 0;
+          stats_callback = {},
+      bool echo = true) = 0;
 
   inline void stop() {
     text_token_generator_->stop();
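
Defaulting echo to true at the interface keeps existing C++ call sites source-compatible and preserves their old prompt-echoing behavior; only callers that explicitly pass false opt out. The Java API mirrors this through its echo-less overloads, which forward DEFAULT_ECHO, for example (repeated from LlamaModule.java above):

public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
  return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO);
}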
