Skip to content

Commit 8929100

Browse files
committed
[llava] Expose prefill image and prompt APIs
Summary: We want to expose prefill_images() and prefill_prompt() for the Llava runner. These APIs will be called by the JNI/demo app so that we can prefill asynchronously. Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent ee752f0 commit 8929100

File tree

4 files changed

+57
-26
lines changed

4 files changed

+57
-26
lines changed

examples/models/llava/runner/llava_runner.cpp

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,27 @@ Error LlavaRunner::load() {
7272
return Error::Ok;
7373
}
7474

75+
Error LlavaRunner::prefill_images(
    std::vector<Image>& images,
    int64_t& start_pos) {
  // Feed every image through the image prefiller in order. start_pos is
  // taken by reference and advanced inside each prefill() call, so after
  // this loop it points one past the last image token in the KV cache.
  for (auto& img : images) {
    ET_UNWRAP(image_prefiller_->prefill(img, start_pos));
  }
  return Error::Ok;
}
84+
85+
Result<uint64_t> LlavaRunner::prefill_prompt(
    const std::string& prompt,
    int64_t& start_pos,
    int8_t bos,
    int8_t eos) {
  // Tokenize the prompt with the requested number of BOS/EOS tokens,
  // then hand the token ids to the text prefiller. start_pos is advanced
  // inside prefill(); the returned value is the next predicted token.
  std::vector<uint64_t> tokens = ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));
  return text_prefiller_->prefill(tokens, start_pos);
}
95+
7596
Error LlavaRunner::generate(
7697
std::vector<Image> images,
7798
const std::string& prompt,
@@ -96,36 +117,23 @@ Error LlavaRunner::generate(
96117
int64_t pos = 0;
97118

98119
// prefill preset prompt
99-
std::vector<uint64_t> preset_prompt_tokens =
100-
ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
101-
size_t num_preset_tokens = preset_prompt_tokens.size();
102-
103-
ET_UNWRAP(text_prefiller_->prefill(preset_prompt_tokens, pos));
104-
pos += num_preset_tokens;
120+
prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
105121

106122
// prefill images
107-
for (auto& image : images) {
108-
// pos is updated inside image prefill.
109-
ET_UNWRAP(image_prefiller_->prefill(image, pos));
110-
}
123+
prefill_images(images, pos);
111124

112125
// prefill user prompt. No BOS because preset prompt already has it.
113126
wrapped_callback(prompt);
114127

115-
std::vector<uint64_t> user_prompt_tokens =
116-
ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
117-
size_t num_user_tokens = user_prompt_tokens.size();
118-
119128
uint64_t prefill_next_token =
120-
ET_UNWRAP(text_prefiller_->prefill(user_prompt_tokens, pos));
121-
pos += num_user_tokens;
129+
ET_UNWRAP(prefill_prompt(prompt, pos, /*bos=*/0, /*eos*/ 0));
130+
stats_.num_prompt_tokens = pos;
122131

123132
// Generate tokens
124133
int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
125134
{prefill_next_token}, pos, seq_len, wrapped_callback));
126135

127136
// Bookkeeping
128-
stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
129137
stats_.num_generated_tokens = num_generated_tokens;
130138
::executorch::llm::print_report(stats_);
131139
if (stats_callback) {

examples/models/llava/runner/llava_runner.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,30 @@ class LlavaRunner : public MultimodalRunner {
3838
std::function<void(const ::executorch::extension::llm::Stats&)>
3939
stats_callback = {});
4040

41+
/**
42+
* Prefill an LLaVA Module with the given images input.
43+
* @param images The image input to LLaVA.
44+
* @param start_pos The starting position in KV cache of the input in the LLM.
45+
* It's passed as reference and will be updated inside this function.
46+
* @return The error status of prefilling images.
47+
*/
48+
Error prefill_images(std::vector<Image>& images, int64_t& start_pos);
49+
50+
/**
51+
* Prefill an LLaVA Module with the given text input.
52+
* @param prompt The text prompt to LLaVA.
53+
* @param start_pos The starting position in KV cache of the input in the LLM.
54+
* It's passed as reference and will be updated inside this function.
55+
* @param bos The number of BOS (begin of sequence) tokens.
56+
* @param eos The number of EOS (end of sequence) tokens.
57+
* @return The token generated by the LLaVA Module after prefilling the prompt.
58+
*/
59+
Result<uint64_t> prefill_prompt(
60+
const std::string& prompt,
61+
int64_t& start_pos,
62+
int8_t bos = 0,
63+
int8_t eos = 0);
64+
4165
private:
4266
inline static const std::string kPresetPrompt =
4367
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";

extension/llm/runner/text_prefiller.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ TextPrefiller::TextPrefiller(
2525

2626
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
2727
std::vector<uint64_t>& prompt_tokens,
28-
int64_t start_pos) {
28+
int64_t& start_pos) {
2929
ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
3030
if (!text_decoder_runner_->is_method_loaded()) {
3131
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -53,11 +53,10 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
5353
ET_LOG(
5454
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
5555

56+
start_pos += num_prompt_tokens;
5657
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
5758
} else { // sequential prefill
5859
int64_t pos = 0; // position in the sequence
59-
// token & pos
60-
int64_t pos_data = 0;
6160
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
6261
cur_token = prompt_tokens[0];
6362

@@ -66,18 +65,17 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
6665
&cur_token, {1, 1}, exec_aten::ScalarType::Long);
6766

6867
ManagedTensor managed_start_pos(
69-
&pos_data, {1}, exec_aten::ScalarType::Long);
68+
&start_pos, {1}, exec_aten::ScalarType::Long);
7069

71-
// run the first token and get back logits tensor. Assuming the first token
72-
// is bos so don't callback.
70+
// run the first token and get back logits tensor.
7371
exec_aten::Tensor logits_tensor = ET_UNWRAP(
7472
text_decoder_runner_->step(managed_tokens, managed_start_pos));
7573

76-
pos = 1; // start from index 1
74+
pos += 1; // start the loop from index 1
75+
start_pos += 1;
7776

7877
while (pos < num_prompt_tokens) {
7978
// Run the model
80-
pos_data = start_pos + pos;
8179

8280
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
8381
cur_token = prompt_tokens[pos];
@@ -86,6 +84,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
8684
text_decoder_runner_->step(managed_tokens, managed_start_pos));
8785

8886
pos++;
87+
start_pos++;
8988
}
9089

9190
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);

extension/llm/runner/text_prefiller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class TextPrefiller {
3636
*/
3737
::executorch::runtime::Result<uint64_t> prefill(
3838
std::vector<uint64_t>& prompt_tokens,
39-
int64_t start_pos = 0);
39+
int64_t& start_pos);
4040

4141
private:
4242
TextDecoderRunner* text_decoder_runner_;

0 commit comments

Comments
 (0)