[llava] Expose prefill image and prompt APIs #5119

Merged: 6 commits, Sep 6, 2024

examples/models/llama2/runner/runner.cpp (4 changes: 2 additions & 2 deletions)

@@ -204,8 +204,8 @@ Error Runner::generate(

   // print prompts
   wrapped_callback(prompt);
-
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, 0);
+  int64_t pos = 0;
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();
   stats_.prompt_eval_end_ms = util::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());

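Why the call site gains a named variable: this PR changes TextPrefiller::prefill to take start_pos as int64_t& (see extension/llm/runner/text_prefiller.h below), and a non-const lvalue reference cannot bind to the literal 0, so runner.cpp introduces a local that the call then advances in place. A minimal standalone illustration of the binding rule (the names here are placeholders, not ExecuTorch APIs):

```cpp
#include <cstdint>

// Stand-in for the new prefill signature.
void prefill_stub(int64_t& start_pos) { start_pos += 4; }

int main() {
  // prefill_stub(0);  // ill-formed: int64_t& cannot bind to a prvalue
  int64_t pos = 0;     // named object, as in the runner.cpp change above
  prefill_stub(pos);   // OK: pos is now 4, ready for the next segment
  return 0;
}
```
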
examples/models/llava/runner/llava_runner.cpp (85 changes: 52 additions & 33 deletions)

@@ -72,6 +72,54 @@ Error LlavaRunner::load() {
   return Error::Ok;
 }
 
+Error LlavaRunner::prefill_images(
+    std::vector<Image>& images,
+    int64_t& start_pos) {
+  for (auto& image : images) {
+    // pos is updated inside image prefill.
+    ET_UNWRAP(image_prefiller_->prefill(image, start_pos));
+  }
+  return Error::Ok;
+}
+
+Result<uint64_t> LlavaRunner::prefill_prompt(
+    const std::string& prompt,
+    int64_t& start_pos,
+    int8_t bos,
+    int8_t eos) {
+  std::vector<uint64_t> prompt_tokens =
+      ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));
+
+  return text_prefiller_->prefill(prompt_tokens, start_pos);
+}
+
+Error LlavaRunner::generate_from_pos(
+    const std::string& prompt,
+    int32_t seq_len,
+    int64_t start_pos,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const ::executorch::extension::llm::Stats&)>
+        stats_callback) {
+  // prefill user prompt. No BOS because preset prompt already has it.
+  token_callback(prompt);
+
+  uint64_t prefill_next_token =
+      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
+  stats_.num_prompt_tokens = start_pos;
+
+  // Generate tokens
+  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
+      {prefill_next_token}, start_pos, seq_len, token_callback));
+
+  // Bookkeeping
+  stats_.num_generated_tokens = num_generated_tokens;
+  ::executorch::llm::print_report(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }
+  return Error::Ok;
+}
+
 Error LlavaRunner::generate(
     std::vector<Image> images,
     const std::string& prompt,
@@ -96,43 +144,14 @@ Error LlavaRunner::generate(
   int64_t pos = 0;
 
   // prefill preset prompt
-  std::vector<uint64_t> preset_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
-  size_t num_preset_tokens = preset_prompt_tokens.size();
-
-  ET_UNWRAP(text_prefiller_->prefill(preset_prompt_tokens, pos));
-  pos += num_preset_tokens;
+  prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
 
   // prefill images
-  for (auto& image : images) {
-    // pos is updated inside image prefill.
-    ET_UNWRAP(image_prefiller_->prefill(image, pos));
-  }
-
-  // prefill user prompt. No BOS because preset prompt already has it.
-  wrapped_callback(prompt);
-
-  std::vector<uint64_t> user_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
-  size_t num_user_tokens = user_prompt_tokens.size();
-
-  uint64_t prefill_next_token =
-      ET_UNWRAP(text_prefiller_->prefill(user_prompt_tokens, pos));
-  pos += num_user_tokens;
+  prefill_images(images, pos);
 
-  // Generate tokens
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, pos, seq_len, wrapped_callback));
-
-  // Bookkeeping
-  stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
-  if (stats_callback) {
-    stats_callback(stats_);
-  }
-
-  return Error::Ok;
+  return generate_from_pos(
+      prompt, seq_len, pos, wrapped_callback, stats_callback);
 }
 
 } // namespace torch::executor

examples/models/llava/runner/llava_runner.h (42 changes: 42 additions & 0 deletions)

@@ -38,6 +38,48 @@ class LlavaRunner : public MultimodalRunner {
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {});
 
+  /**
+   * Prefill an LLaVA Module with the given image inputs.
+   * @param images The image inputs to LLaVA.
+   * @param start_pos The starting position of the input in the LLM's KV
+   *     cache. It is passed by reference and updated inside this function.
+   * @return The error status of prefilling the images.
+   */
+  Error prefill_images(std::vector<Image>& images, int64_t& start_pos);
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   * @param prompt The text prompt to LLaVA.
+   * @param start_pos The starting position of the input in the LLM's KV
+   *     cache. It is passed by reference and updated inside this function.
+   * @param bos The number of BOS (begin of sequence) tokens.
+   * @param eos The number of EOS (end of sequence) tokens.
+   * @return The token generated by the LLaVA Module after prefilling the
+   *     prompt.
+   */
+  Result<uint64_t> prefill_prompt(
+      const std::string& prompt,
+      int64_t& start_pos,
+      int8_t bos = 0,
+      int8_t eos = 0);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   *     new tokens.
+   * @param start_pos The starting position of the input in the LLM's KV
+   *     cache.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with Stats.
+   * @return The error code.
+   */
+  Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {});
+
  private:
   inline static const std::string kPresetPrompt =
       "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";

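Taken together, the three new methods let a caller drive prefill explicitly instead of going through generate()'s fixed order of preset prompt, then images, then user prompt. The sketch below mirrors what generate() now does internally; the runner setup, using-declarations, and prompt strings are illustrative assumptions, not part of this diff:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "examples/models/llava/runner/llava_runner.h"

using torch::executor::Error;       // namespaces assumed from the diff's
using torch::executor::Image;       // `} // namespace torch::executor`
using torch::executor::LlavaRunner;

// Sketch only: assumes a runner that is already constructed and load()-ed,
// plus caller-provided images.
void run_llava(LlavaRunner& runner, std::vector<Image>& images) {
  int64_t pos = 0; // KV-cache position; each prefill call advances it

  // 1. Preset (system) prompt, with BOS. Same text as the private
  //    kPresetPrompt that generate() uses.
  const std::string preset =
      "A chat between a curious human and an artificial intelligence "
      "assistant. The assistant gives helpful, detailed, and polite answers "
      "to the human's questions. USER: ";
  auto preset_res = runner.prefill_prompt(preset, pos, /*bos=*/1, /*eos=*/0);
  if (preset_res.error() != Error::Ok) {
    return;
  }

  // 2. Image embeddings at the current position.
  if (runner.prefill_images(images, pos) != Error::Ok) {
    return;
  }

  // 3. User prompt (no BOS; the preset already supplied it) and decode,
  //    continuing from the accumulated position.
  runner.generate_from_pos(
      "What is on the picture? ASSISTANT:",
      /*seq_len=*/768,
      pos,
      [](const std::string& piece) { std::cout << piece; });
}
```
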
extension/llm/runner/text_prefiller.cpp (25 changes: 13 additions & 12 deletions)

@@ -25,7 +25,7 @@ TextPrefiller::TextPrefiller(
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
-    int64_t start_pos_index) {
+    int64_t& start_pos) {
   ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -43,45 +43,46 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
         {1, num_prompt_tokens},
         exec_aten::ScalarType::Long);
 
-    auto start_pos =
-        from_blob(&start_pos_index, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
-    auto outputs_res = text_decoder_runner_->step(tokens, start_pos);
+    auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
 
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_LOG(
         Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
 
+    start_pos += num_prompt_tokens;
     cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
   } else { // sequential prefill
     int64_t pos = 0; // position in the sequence
-    // token & pos
-    int64_t pos_data = 0;
     // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
     auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
 
-    auto start_pos = from_blob(&pos_data, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
     // run the first token and get back logits tensor. Assuming the first token
    // is bos so don't callback.
     auto logits_tensor =
-        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
-    pos = 1; // start from index 1
+    pos += 1; // start the loop from index 1
+    start_pos += 1;
 
     while (pos < num_prompt_tokens) {
       // Run the model
-      pos_data = start_pos_index + pos;
-
       // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
       cur_token = prompt_tokens[pos];
 
-      logits_tensor = ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+      logits_tensor =
+          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
       pos++;
+      start_pos++;
     }
 
     cur_token = text_decoder_runner_->logits_to_token(logits_tensor);

extension/llm/runner/text_prefiller.h (2 changes: 1 addition & 1 deletion)

@@ -36,7 +36,7 @@ class TextPrefiller {
    */
   ::executorch::runtime::Result<uint64_t> prefill(
       std::vector<uint64_t>& prompt_tokens,
-      int64_t start_pos = 0);
+      int64_t& start_pos);
 
  private:
   TextDecoderRunner* text_decoder_runner_;

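The practical effect of the new signature: start_pos is both read and written, so back-to-back prefill calls continue in the KV cache exactly where the previous segment stopped. A sketch of the intended call pattern, with namespace qualification, prefiller construction, and token values assumed for illustration:

```cpp
#include <cstdint>
#include <vector>

#include "extension/llm/runner/text_prefiller.h"

// Sketch only: assumes an already-constructed TextPrefiller; the token ids
// are placeholders standing in for real tokenizer output.
void prefill_two_segments(TextPrefiller& prefiller) {
  std::vector<uint64_t> system_tokens = {1, 319, 13563};   // placeholder ids
  std::vector<uint64_t> user_tokens = {3148, 1001, 29901}; // placeholder ids

  int64_t start_pos = 0;

  // After this call, start_pos == system_tokens.size(); the returned value
  // is the model's predicted token following the segment.
  auto first = prefiller.prefill(system_tokens, start_pos);

  // The second segment lands in the KV cache immediately after the first;
  // no manual `start_pos += ...` arithmetic at the call site.
  auto next = prefiller.prefill(user_tokens, start_pos);

  // start_pos now covers both segments; check first.error() / next.error()
  // before using the tokens in real code.
}
```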