Skip to content

Commit 032a35d

Browse files
committed
Add Echo parameter to multimodal runner (llava) and jni layer
1 parent 99fbca3 commit 032a35d

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

examples/models/llava/runner/llava_runner.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,12 @@ Error LlavaRunner::generate_from_pos(
9999
int64_t start_pos,
100100
std::function<void(const std::string&)> token_callback,
101101
std::function<void(const ::executorch::extension::llm::Stats&)>
102-
stats_callback) {
102+
stats_callback,
103+
bool echo) {
103104
// prefill user prompt. No BOS because preset prompt already has it.
104-
token_callback(prompt);
105+
if (echo) {
106+
token_callback(prompt);
107+
}
105108

106109
uint64_t prefill_next_token =
107110
ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
@@ -125,7 +128,8 @@ Error LlavaRunner::generate(
125128
const std::string& prompt,
126129
int32_t seq_len,
127130
std::function<void(const std::string&)> token_callback,
128-
std::function<void(const Stats&)> stats_callback) {
131+
std::function<void(const Stats&)> stats_callback,
132+
bool echo) {
129133
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
130134
if (!is_loaded()) {
131135
ET_CHECK_OK_OR_RETURN_ERROR(load());

examples/models/llava/runner/llava_runner.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class LlavaRunner : public MultimodalRunner {
3636
int32_t seq_len = 1024,
3737
std::function<void(const std::string&)> token_callback = {},
3838
std::function<void(const ::executorch::extension::llm::Stats&)>
39-
stats_callback = {});
39+
stats_callback = {},
40+
bool echo = true);
4041

4142
/**
4243
* Prefill an LLaVA Module with the given images input.
@@ -70,6 +71,7 @@ class LlavaRunner : public MultimodalRunner {
7071
* @param start_pos The starting position in KV cache of the input in the LLM.
7172
* @param token_callback What to do after a token is generated.
7273
* @param stats_callback What to do with Stats.
74+
* @param echo Whether to echo the input prompt or not.
7375
* @return The error code.
7476
*/
7577
Error generate_from_pos(
@@ -78,7 +80,8 @@ class LlavaRunner : public MultimodalRunner {
7880
int64_t start_pos = 0,
7981
std::function<void(const std::string&)> token_callback = {},
8082
std::function<void(const ::executorch::extension::llm::Stats&)>
81-
stats_callback = {});
83+
stats_callback = {},
84+
bool echo = true);
8285

8386
private:
8487
inline static const std::string kPresetPrompt =

extension/android/jni/jni_layer_llama.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ class ExecuTorchLlamaJni
170170
prompt->toStdString(),
171171
seq_len,
172172
[callback](std::string result) { callback->onResult(result); },
173-
[callback](const Stats& result) { callback->onStats(result); });
173+
[callback](const Stats& result) { callback->onStats(result); },
174+
echo);
174175
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
175176
runner_->generate(
176177
prompt->toStdString(),

extension/llm/runner/multimodal_runner.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ class MultimodalRunner {
5959
const std::string& prompt,
6060
int32_t seq_len = 1024,
6161
std::function<void(const std::string&)> token_callback = {},
62-
std::function<void(const Stats&)> stats_callback = {}) = 0;
62+
std::function<void(const Stats&)> stats_callback = {},
63+
bool echo = true) = 0;
6364

6465
/**
6566
* Prefill an LLaVA Module with the given images input.
@@ -95,6 +96,7 @@ class MultimodalRunner {
9596
* @param start_pos The starting position in KV cache of the input in the LLM.
9697
* @param token_callback What to do after a token is generated.
9798
* @param stats_callback What to do with Stats.
99+
* @param echo Whether to echo the input prompt or not.
98100
* @return The error code.
99101
*/
100102
virtual runtime::Error generate_from_pos(
@@ -103,7 +105,8 @@ class MultimodalRunner {
103105
int64_t start_pos = 0,
104106
std::function<void(const std::string&)> token_callback = {},
105107
std::function<void(const ::executorch::extension::llm::Stats&)>
106-
stats_callback = {}) = 0;
108+
stats_callback = {},
109+
bool echo = true) = 0;
107110

108111
inline void stop() {
109112
text_token_generator_->stop();

0 commit comments

Comments
 (0)