
Commit 478f1e2

Update with generate_from_pos
1 parent 5a75ccc commit 478f1e2

File tree

2 files changed (+47, -18 lines)


examples/models/llava/runner/llava_runner.cpp

Lines changed: 29 additions & 18 deletions
@@ -93,6 +93,33 @@ Result<uint64_t> LlavaRunner::prefill_prompt(
   return text_prefiller_->prefill(prompt_tokens, start_pos);
 }

+Error LlavaRunner::generate_from_pos(
+    const std::string& prompt,
+    int32_t seq_len,
+    int64_t start_pos,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const ::executorch::extension::llm::Stats&)>
+        stats_callback) {
+  // Prefill the user prompt. No BOS because the preset prompt already has it.
+  token_callback(prompt);
+
+  uint64_t prefill_next_token =
+      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos=*/0));
+  stats_.num_prompt_tokens = start_pos;
+
+  // Generate tokens.
+  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
+      {prefill_next_token}, start_pos, seq_len, token_callback));
+
+  // Bookkeeping.
+  stats_.num_generated_tokens = num_generated_tokens;
+  ::executorch::llm::print_report(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }
+  return Error::Ok;
+}
+
 Error LlavaRunner::generate(
     std::vector<Image> images,
     const std::string& prompt,
@@ -122,25 +149,9 @@ Error LlavaRunner::generate(
   // prefill images
   prefill_images(images, pos);

-  // prefill user prompt. No BOS because preset prompt already has it.
-  wrapped_callback(prompt);
-
-  uint64_t prefill_next_token =
-      ET_UNWRAP(prefill_prompt(prompt, pos, /*bos=*/0, /*eos*/ 0));
-  stats_.num_prompt_tokens = pos;
-
   // Generate tokens
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, pos, seq_len, wrapped_callback));
-
-  // Bookkeeping
-  stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
-  if (stats_callback) {
-    stats_callback(stats_);
-  }
-
-  return Error::Ok;
+  return generate_from_pos(
+      prompt, seq_len, pos, wrapped_callback, stats_callback);
 }

 } // namespace torch::executor

examples/models/llava/runner/llava_runner.h

Lines changed: 18 additions & 0 deletions
@@ -62,6 +62,24 @@ class LlavaRunner : public MultimodalRunner {
       int8_t bos = 0,
       int8_t eos = 0);

+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   *     the newly generated tokens.
+   * @param start_pos The starting position in the LLM's KV cache for the input.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with the generation Stats.
+   * @return The error code.
+   */
+  Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {});
+
  private:
   inline static const std::string kPresetPrompt =
       "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
