
Commit 9739609

[llava] Expose prefill image and prompt APIs
Differential Revision: D62273041
Pull Request resolved: #5119
1 parent 030fc3f commit 9739609

File tree

5 files changed: +110 −48 lines changed


examples/models/llama2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions

@@ -204,8 +204,8 @@ Error Runner::generate(
 
   // print prompts
   wrapped_callback(prompt);
-
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, 0);
+  int64_t pos = 0;
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();
   stats_.prompt_eval_end_ms = util::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
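The literal 0 becomes a named pos because TextPrefiller::prefill now takes
the position by reference and advances it internally (see
extension/llm/runner/text_prefiller.cpp below). A minimal caller-side sketch
of the new post-condition, with illustrative names:

    std::vector<uint64_t> prompt_tokens = {/* encoded prompt */};
    int64_t pos = 0;
    auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
    // On success, prefill has advanced pos by prompt_tokens.size(), so
    // generation can continue from pos without recomputing it.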

examples/models/llava/runner/llava_runner.cpp

Lines changed: 52 additions & 33 deletions

@@ -72,6 +72,54 @@ Error LlavaRunner::load() {
   return Error::Ok;
 }
 
+Error LlavaRunner::prefill_images(
+    std::vector<Image>& images,
+    int64_t& start_pos) {
+  for (auto& image : images) {
+    // pos is updated inside image prefill.
+    ET_UNWRAP(image_prefiller_->prefill(image, start_pos));
+  }
+  return Error::Ok;
+}
+
+Result<uint64_t> LlavaRunner::prefill_prompt(
+    const std::string& prompt,
+    int64_t& start_pos,
+    int8_t bos,
+    int8_t eos) {
+  std::vector<uint64_t> prompt_tokens =
+      ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));
+
+  return text_prefiller_->prefill(prompt_tokens, start_pos);
+}
+
+Error LlavaRunner::generate_from_pos(
+    const std::string& prompt,
+    int32_t seq_len,
+    int64_t start_pos,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const ::executorch::extension::llm::Stats&)>
+        stats_callback) {
+  // prefill user prompt. No BOS because preset prompt already has it.
+  token_callback(prompt);
+
+  uint64_t prefill_next_token =
+      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
+  stats_.num_prompt_tokens = start_pos;
+
+  // Generate tokens
+  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
+      {prefill_next_token}, start_pos, seq_len, token_callback));
+
+  // Bookkeeping
+  stats_.num_generated_tokens = num_generated_tokens;
+  ::executorch::llm::print_report(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }
+  return Error::Ok;
+}
+
 Error LlavaRunner::generate(
     std::vector<Image> images,
     const std::string& prompt,
@@ -96,43 +144,14 @@ Error LlavaRunner::generate(
   int64_t pos = 0;
 
   // prefill preset prompt
-  std::vector<uint64_t> preset_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
-  size_t num_preset_tokens = preset_prompt_tokens.size();
-
-  ET_UNWRAP(text_prefiller_->prefill(preset_prompt_tokens, pos));
-  pos += num_preset_tokens;
+  prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
 
   // prefill images
-  for (auto& image : images) {
-    // pos is updated inside image prefill.
-    ET_UNWRAP(image_prefiller_->prefill(image, pos));
-  }
-
-  // prefill user prompt. No BOS because preset prompt already has it.
-  wrapped_callback(prompt);
-
-  std::vector<uint64_t> user_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
-  size_t num_user_tokens = user_prompt_tokens.size();
-
-  uint64_t prefill_next_token =
-      ET_UNWRAP(text_prefiller_->prefill(user_prompt_tokens, pos));
-  pos += num_user_tokens;
+  prefill_images(images, pos);
 
   // Generate tokens
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, pos, seq_len, wrapped_callback));
-
-  // Bookkeeping
-  stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
-  if (stats_callback) {
-    stats_callback(stats_);
-  }
-
-  return Error::Ok;
+  return generate_from_pos(
+      prompt, seq_len, pos, wrapped_callback, stats_callback);
 }
 
 } // namespace torch::executor
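After this refactor the position bookkeeping lives entirely in the callees:
prefill_prompt and prefill_images advance pos through their reference
parameter, so when generate_from_pos records stats_.num_prompt_tokens =
start_pos, that count covers the preset prompt, the image tokens, and the
user prompt together. An illustrative trace, with made-up token counts:

    int64_t pos = 0;
    // prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
    //   -> pos == 35   (assumed length of the preset prompt)
    // prefill_images(images, pos);
    //   -> pos == 611  (assumed one image worth 576 embedding positions)
    // generate_from_pos(prompt, seq_len, pos, ...) prefills the user
    // prompt (assumed 9 tokens) -> start_pos == 620, the value stored
    // in stats_.num_prompt_tokens before token generation begins.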

examples/models/llava/runner/llava_runner.h

Lines changed: 42 additions & 0 deletions

@@ -38,6 +38,48 @@ class LlavaRunner : public MultimodalRunner {
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {});
 
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   * @param images The image input to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It's passed as reference and will be updated inside this function.
+   * @return The error status of prefilling images.
+   */
+  Error prefill_images(std::vector<Image>& images, int64_t& start_pos);
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   * @param prompt The text prompt to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It's passed as reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) tokens to prepend.
+   * @param eos The number of EOS (end of sequence) tokens to append.
+   * @return The token generated by the LLaVA Module after prefilling the prompt.
+   */
+  Result<uint64_t> prefill_prompt(
+      const std::string& prompt,
+      int64_t& start_pos,
+      int8_t bos = 0,
+      int8_t eos = 0);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   * new tokens.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with Stats.
+   * @return The error code.
+   */
+  Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {});
+
  private:
   inline static const std::string kPresetPrompt =
       "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";

extension/llm/runner/text_prefiller.cpp

Lines changed: 13 additions & 12 deletions

@@ -25,7 +25,7 @@ TextPrefiller::TextPrefiller(
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
-    int64_t start_pos_index) {
+    int64_t& start_pos) {
   ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -43,45 +43,46 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
         {1, num_prompt_tokens},
         exec_aten::ScalarType::Long);
 
-    auto start_pos =
-        from_blob(&start_pos_index, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
-    auto outputs_res = text_decoder_runner_->step(tokens, start_pos);
+    auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
 
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_LOG(
         Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
 
+    start_pos += num_prompt_tokens;
     cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
   } else { // sequential prefill
     int64_t pos = 0; // position in the sequence
-    // token & pos
-    int64_t pos_data = 0;
     // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
     auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
 
-    auto start_pos = from_blob(&pos_data, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
     // run the first token and get back logits tensor. Assuming the first token
    // is bos so don't callback.
     auto logits_tensor =
-        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
-    pos = 1; // start from index 1
+    pos += 1; // start the loop from index 1
+    start_pos += 1;
 
     while (pos < num_prompt_tokens) {
       // Run the model
-      pos_data = start_pos_index + pos;
-
       // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
       cur_token = prompt_tokens[pos];
 
-      logits_tensor = ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+      logits_tensor =
+          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
       pos++;
+      start_pos++;
     }
 
     cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
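Both branches now rely on from_blob wrapping the caller's memory rather than
copying it: start_pos_tensor aliases start_pos, so every start_pos++ in the
sequential loop is exactly what the next step() call observes as its
position input. A minimal sketch of that aliasing, stripped of the runner
machinery (same from_blob as above):

    int64_t start_pos = 0;
    auto start_pos_tensor =
        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
    start_pos += 1;
    // start_pos_tensor now reads back 1: it points at &start_pos, so no
    // explicit copy into the tensor is needed before each step().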

extension/llm/runner/text_prefiller.h

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ class TextPrefiller {
    */
   ::executorch::runtime::Result<uint64_t> prefill(
       std::vector<uint64_t>& prompt_tokens,
-      int64_t start_pos = 0);
+      int64_t& start_pos);
 
  private:
   TextDecoderRunner* text_decoder_runner_;
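Dropping the "= 0" default is what forces call sites to keep the position in
a variable: a non-const reference cannot bind to a literal. A one-line
illustration of the changed contract (the prefiller object is hypothetical):

    int64_t pos = 0;
    auto next = prefiller.prefill(prompt_tokens, pos); // ok: pos is an lvalue
    // prefiller.prefill(prompt_tokens, 0);            // no longer compiles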
