Skip to content

Commit 84100d1

Browse files
authored
[llava][20/N] Add llava runner using building blocks in e/llm/runner (#4666)
* [llava][18/N] Move token generation loop to a class As titled. This PR moves the token generation loop in llama2 runner into a new class so it can be reused. [ghstack-poisoned] * [llava][19/N] Add multimodal runner base class and build file [ghstack-poisoned] * [llava][20/N] Add llava runner using building blocks in e/llm/runner [ghstack-poisoned] * Update base for Update on "[llava][20/N] Add llava runner using building blocks in e/llm/runner" Add llava runner that uses runner lib in `extension/llm/runner`. [ghstack-poisoned] * Update base for Update on "[llava][20/N] Add llava runner using building blocks in e/llm/runner" Add llava runner that uses runner lib in `extension/llm/runner`. Differential Revision: [D61292846](https://our.internmc.facebook.com/intern/diff/D61292846) [ghstack-poisoned]
1 parent ef56414 commit 84100d1

File tree

5 files changed

+346
-0
lines changed

5 files changed

+346
-0
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Simple CMake build system for LLaVa runner.
#
# ### Editing this file ###
#
# This file should be formatted with
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
# It should also be cmake-lint clean.
#

if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()

include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)

# Let files say "include <executorch/path/to/header.h>".
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# build llava_runner library
set(_llava_runner__srcs
    "${CMAKE_CURRENT_SOURCE_DIR}/llava_runner.cpp"
    "${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp"
    "${EXECUTORCH_ROOT}/extension/llm/tokenizer/bpe_tokenizer.cpp"
)

# extension llm runner lib. Use EXECUTORCH_ROOT for the source dir (consistent
# with the rest of this file) and keep the binary dir inside the current build
# directory rather than escaping it with ../ segments, which is fragile when
# this example is configured standalone.
add_subdirectory(
  ${EXECUTORCH_ROOT}/extension/llm/runner
  ${CMAKE_CURRENT_BINARY_DIR}/extension_llm_runner
)

add_library(llava_runner STATIC ${_llava_runner__srcs})

set(llava_runner_deps executorch extension_module extension_data_loader
                      extension_llm_runner
)

target_link_libraries(llava_runner PUBLIC ${llava_runner_deps})

target_include_directories(
  llava_runner INTERFACE ${_common_include_directories} ${EXECUTORCH_ROOT}
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Given a image tensor, prefill the KV cache of LLaVA.
10+
11+
#pragma once
12+
13+
#include <executorch/extension/llm/runner/image_prefiller.h>
14+
#include <executorch/extension/runner_util/managed_tensor.h>
15+
16+
namespace torch::executor {
17+
18+
class LlavaImagePrefiller : public ImagePrefiller {
19+
public:
20+
LlavaImagePrefiller(Module* module) : ImagePrefiller(module){};
21+
/**
22+
* Prefill an LLM Module with the given image input.
23+
* @param image The image input to LLaVa.
24+
* @param start_pos The starting position in KV cache of the input in the LLM
25+
* @return logits of the image prefill.
26+
*/
27+
inline Result<exec_aten::Tensor> prefill(
28+
Image& image,
29+
int64_t start_pos = 0) {
30+
ManagedTensor managed_images(
31+
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
32+
// Run image encoder
33+
std::vector<EValue> image_encoder_outputs = ET_UNWRAP(module_->execute(
34+
"image_encoder", {managed_images.get_aliasing_tensor()}));
35+
36+
// inputs:[start_pos, embeds]
37+
ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);
38+
auto start_pos_tensor = managed_start_pos.get_aliasing_tensor();
39+
40+
// Run text model
41+
std::vector<EValue> outputs_res = ET_UNWRAP(module_->execute(
42+
"text_decoder", {start_pos_tensor, image_encoder_outputs[0]}));
43+
ET_CHECK_MSG(
44+
outputs_res[0].isTensor(),
45+
"Non Tensor Output returned from executing image prefill");
46+
47+
return outputs_res[0].toTensor();
48+
}
49+
};
50+
51+
} // namespace torch::executor
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// A simple LLaVA runner that includes preprocessing and post processing logic.
10+
// The runner takes in a prompt string as well as a list of images as input and
11+
// emits a string as output.
12+
13+
#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
14+
#include <executorch/examples/models/llava/runner/llava_runner.h>
15+
#include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
16+
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
17+
18+
#include <ctime>
19+
#include <memory>
20+
#include <sstream>
21+
#include <vector>
22+
23+
namespace torch::executor {
24+
25+
bool LlavaRunner::is_loaded() {
26+
Result<std::unordered_set<std::string>> methods_res = module_->method_names();
27+
if (methods_res.error() != Error::Ok) {
28+
ET_LOG(Error, "Failed to get method names");
29+
ET_CHECK_MSG(false, "Failed to get method names");
30+
}
31+
std::unordered_set<std::string> methods = methods_res.get();
32+
bool methods_exist = methods.find("image_encoder") != methods.end() &&
33+
methods.find("token_embedding") != methods.end() &&
34+
methods.find("text_decoder") != methods.end();
35+
if (!methods_exist) {
36+
for (const auto& method : methods) {
37+
ET_LOG(Error, "Method: %s", method.c_str());
38+
}
39+
ET_CHECK_MSG(
40+
methods_exist,
41+
"Missing required methods (image_encoder, token_embedding, text_decoder) in the model");
42+
}
43+
bool methods_loaded = module_->is_method_loaded("image_encoder") &&
44+
module_->is_method_loaded("token_embedding") &&
45+
module_->is_method_loaded("text_decoder");
46+
return methods_loaded && tokenizer_ && text_decoder_runner_ &&
47+
text_prefiller_ && image_prefiller_ && text_token_generator_;
48+
}
49+
50+
Error LlavaRunner::load() {
  // Loading is idempotent: bail out early if everything is already in place.
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_.model_load_start_ms = util::time_in_ms();

  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("image_encoder"));
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("token_embedding"));
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("text_decoder"));

  // Load the tokenizer. Fix: the returned Error was previously discarded,
  // which silently left the runner with an unusable tokenizer on failure.
  tokenizer_ = std::make_unique<BPETokenizer>();
  ET_CHECK_OK_OR_RETURN_ERROR(tokenizer_->load(tokenizer_path_));

  // Load the text decoder runner
  text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
      module_.get(), tokenizer_->vocab_size(), temperature_);

  // Load the text prefiller
  text_prefiller_ = std::make_unique<TextPrefiller>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true);

  // Load the image prefiller
  image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());

  // Load the text token generator
  text_token_generator_ = std::make_unique<TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      tokenizer_->eos_tok(),
      &stats_);

  stats_.model_load_end_ms = util::time_in_ms();
  return Error::Ok;
}
89+
90+
Error LlavaRunner::generate(
    std::vector<Image>& images,
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  // Echo every generated piece to stdout, then forward it to the
  // user-provided callback (if any).
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        util::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  int64_t cache_pos = 0;

  // Phase 1: prefill the preset (system) prompt; it carries the BOS token.
  std::vector<uint64_t> preset_tokens =
      ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
  const size_t num_preset_tokens = preset_tokens.size();

  ET_UNWRAP(text_prefiller_->prefill(preset_tokens, cache_pos));
  cache_pos += num_preset_tokens;

  // Phase 2: prefill each image; the second logits dimension tells us how
  // many KV-cache positions the image embeddings consumed.
  for (auto& img : images) {
    auto img_logits = ET_UNWRAP(image_prefiller_->prefill(img, cache_pos));
    cache_pos += img_logits.size(1);
  }

  // Phase 3: prefill the user prompt. No BOS because preset prompt already
  // has it.
  std::vector<uint64_t> user_tokens =
      ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
  const size_t num_user_tokens = user_tokens.size();

  uint64_t next_token = ET_UNWRAP(
      text_prefiller_->prefill(user_tokens, cache_pos, wrapped_callback));
  cache_pos += num_user_tokens;

  // Phase 4: autoregressive generation until EOS or seq_len.
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      {next_token}, cache_pos, seq_len, wrapped_callback));

  // Bookkeeping
  stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
  stats_.num_generated_tokens = num_generated_tokens;
  ::executorch::llm::print_report(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}
150+
151+
} // namespace torch::executor
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// A simple multimodal LLM runner that includes preprocessing and post
10+
// processing logic.
11+
#pragma once
12+
13+
#include <cstdint>
14+
#include <functional>
15+
#include <memory>
16+
#include <string>
17+
#include <type_traits>
18+
#include <unordered_map>
19+
20+
#include <executorch/extension/llm/runner/multimodal_runner.h>
21+
22+
namespace torch::executor {
23+
24+
class LlavaRunner : public MultimodalRunner {
25+
public:
26+
explicit LlavaRunner(
27+
const std::string& model_path,
28+
const std::string& tokenizer_path,
29+
const float temperature = 0.8f)
30+
: MultimodalRunner(model_path, tokenizer_path, temperature){};
31+
bool is_loaded();
32+
Error load();
33+
Error generate(
34+
std::vector<Image>& images,
35+
const std::string& prompt,
36+
int32_t seq_len = 1024,
37+
std::function<void(const std::string&)> token_callback = {},
38+
std::function<void(const Stats&)> stats_callback = {});
39+
40+
private:
41+
inline static const std::string kPresetPrompt =
42+
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
43+
};
44+
45+
} // namespace torch::executor
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Given inputs, run a text decoder in Llava and return the output.
10+
11+
#pragma once
12+
13+
#include <executorch/extension/llm/runner/text_decoder_runner.h>
14+
15+
namespace torch::executor {
16+
17+
class LlavaTextDecoderRunner : public TextDecoderRunner {
18+
public:
19+
LlavaTextDecoderRunner(Module* module, int32_t vocab_size, float temperature)
20+
: TextDecoderRunner(module, true, vocab_size, temperature){};
21+
22+
Result<exec_aten::Tensor> step(
23+
ManagedTensor& managed_tokens,
24+
ManagedTensor& managed_start_pos) {
25+
auto tokens = managed_tokens.get_aliasing_tensor();
26+
auto start_pos = managed_start_pos.get_aliasing_tensor();
27+
28+
// run token embedding
29+
std::vector<EValue> token_embedding_outputs =
30+
ET_UNWRAP(module_->execute("token_embedding", {tokens}));
31+
32+
// run text model
33+
std::vector<EValue> outputs_res = ET_UNWRAP(module_->execute(
34+
"text_decoder", {start_pos, token_embedding_outputs[0]}));
35+
36+
ET_CHECK_MSG(
37+
outputs_res.size() == 1,
38+
"More then one output returned from executing LLM.");
39+
ET_CHECK_MSG(
40+
outputs_res[0].isTensor(),
41+
"Non Tensor Output returned from executing LLM");
42+
43+
// Return the logits tensor
44+
return outputs_res[0].toTensor();
45+
}
46+
};
47+
48+
} // namespace torch::executor

0 commit comments

Comments
 (0)