Skip to content

Commit 9a84d5a

Browse files
committed
[llava][20/N] Add llava runner using building blocks in e/llm/runner
ghstack-source-id: bbd753f Pull Request resolved: #4666
1 parent 9a88484 commit 9a84d5a

File tree

5 files changed

+351
-0
lines changed

5 files changed

+351
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Simple CMake build system for LLaVa runner.
#
# ### Editing this file ###
#
# This file should be formatted with
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
# It should also be cmake-lint clean.
#

# Default EXECUTORCH_ROOT to the repo root (four levels up from this
# examples/models/llava/runner directory) when not provided by the caller.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()

include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
# Let files say "include <executorch/path/to/header.h>".
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# build llava_runner library
# Sampler and BPE tokenizer sources are compiled directly into this library
# rather than linked as separate targets.
set(_llava_runner__srcs
    "${CMAKE_CURRENT_SOURCE_DIR}/llava_runner.cpp"
    "${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp"
    "${EXECUTORCH_ROOT}/extension/llm/tokenizer/bpe_tokenizer.cpp"
)

# extension llm runner lib
# NOTE(review): the binary dir mirrors the source-relative path; confirm this
# matches how other example runners add this subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/runner
  ${CMAKE_CURRENT_BINARY_DIR}/../../../../extension/llm/runner
)

add_library(llava_runner STATIC ${_llava_runner__srcs})

set(llava_runner_deps executorch extension_module extension_data_loader
                      extension_llm_runner
)

target_link_libraries(llava_runner PUBLIC ${llava_runner_deps})

# INTERFACE so that consumers of llava_runner inherit the executorch-style
# include paths; the library's own sources use the compiler's default lookup.
target_include_directories(
  llava_runner INTERFACE ${_common_include_directories} ${EXECUTORCH_ROOT}
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Given a image tensor, prefill the KV cache of LLaVA.
10+
11+
#pragma once
12+
13+
#include <executorch/extension/llm/runner/image_prefiller.h>
14+
#include <executorch/extension/runner_util/managed_tensor.h>
15+
16+
namespace torch::executor {
17+
18+
class LlavaImagePrefiller : public ImagePrefiller {
19+
public:
20+
LlavaImagePrefiller(Module* module) : ImagePrefiller(module){};
21+
/**
22+
* Prefill an LLM Module with the given image input.
23+
* @param image The image input to LLaVa.
24+
* @param start_pos The starting position in KV cache of the input in the LLM
25+
* @return logits of the image prefill.
26+
*/
27+
inline Result<exec_aten::Tensor> prefill(
28+
Image& image,
29+
int64_t start_pos = 0) {
30+
ManagedTensor managed_images(
31+
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
32+
// Run image encoder
33+
std::vector<EValue> image_encoder_outputs = ET_UNWRAP(module_->execute(
34+
"image_encoder", {managed_images.get_aliasing_tensor()}));
35+
36+
// inputs:[start_pos, embeds]
37+
ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);
38+
auto start_pos_tensor = managed_start_pos.get_aliasing_tensor();
39+
40+
// Run text model
41+
std::vector<EValue> outputs_res = ET_UNWRAP(module_->execute(
42+
"text_decoder", {start_pos_tensor, image_encoder_outputs[0]}));
43+
ET_CHECK_MSG(
44+
outputs_res[0].isTensor(),
45+
"Non Tensor Output returned from executing image prefill");
46+
47+
return outputs_res[0].toTensor();
48+
}
49+
};
50+
51+
} // namespace torch::executor
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// A simple LLaVA runner that includes preprocessing and post processing logic.
10+
// The module takes in a string as input and emits a string as output.
11+
12+
#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
13+
#include <executorch/examples/models/llava/runner/llava_runner.h>
14+
#include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
15+
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
16+
17+
#include <ctime>
18+
#include <memory>
19+
#include <sstream>
20+
#include <vector>
21+
22+
namespace torch::executor {
23+
24+
bool LlavaRunner::is_loaded() {
25+
Result<std::unordered_set<std::string>> methods_res = module_->method_names();
26+
if (methods_res.error() != Error::Ok) {
27+
ET_LOG(Error, "Failed to get method names");
28+
ET_CHECK_MSG(false, "Failed to get method names");
29+
}
30+
std::unordered_set<std::string> methods = methods_res.get();
31+
bool methods_exist = methods.find("image_encoder") != methods.end() &&
32+
methods.find("token_embedding") != methods.end() &&
33+
methods.find("text_decoder") != methods.end();
34+
if (!methods_exist) {
35+
for (const auto& method : methods) {
36+
ET_LOG(Error, "Method: %s", method.c_str());
37+
}
38+
ET_CHECK_MSG(
39+
methods_exist,
40+
"Missing required methods (image_encoder, token_embedding, text_decoder) in the model");
41+
}
42+
bool methods_loaded = module_->is_method_loaded("image_encoder") &&
43+
module_->is_method_loaded("token_embedding") &&
44+
module_->is_method_loaded("text_decoder");
45+
return methods_loaded && tokenizer_ && text_decoder_runner_ &&
46+
text_prefiller_ && image_prefiller_ && text_token_generator_;
47+
}
48+
49+
// Loads the three model methods and constructs every runner component
// (tokenizer, decoder runner, text/image prefillers, token generator).
// Idempotent: returns immediately if everything is already loaded.
Error LlavaRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_.model_load_start_ms = util::time_in_ms();

  // Load each required method up front so generate() does no lazy loading.
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("image_encoder"));
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("token_embedding"));
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("text_decoder"));

  // Load the tokenizer
  // NOTE(review): the return value of load() is ignored here, unlike the
  // checked load_method() calls above — confirm whether a failed tokenizer
  // load should propagate an error.
  tokenizer_ = std::make_unique<BPETokenizer>();
  tokenizer_->load(tokenizer_path_);

  // Load the text decoder runner
  text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
      module_.get(), tokenizer_->vocab_size(), temperature_);

  // Load the text prefiller
  text_prefiller_ = std::make_unique<TextPrefiller>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true);

  // Load the image prefiller
  image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());

  // Load the text token generator
  text_token_generator_ = std::make_unique<TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      tokenizer_->eos_tok(),
      &stats_);

  stats_.model_load_end_ms = util::time_in_ms();
  return Error::Ok;
}
88+
89+
/**
 * Runs end-to-end LLaVA generation: prefills the preset system prompt, then
 * each image, then the user prompt, and finally generates up to seq_len
 * tokens, streaming each decoded piece to token_callback (and stdout).
 *
 * @param images Images to prefill into the KV cache, in order.
 * @param prompt User prompt text; must be non-empty.
 * @param seq_len Upper bound on total sequence length for generation.
 * @param token_callback Optional per-token string callback.
 * @param stats_callback Optional callback invoked once with final stats.
 * @return Error::Ok on success, or the first propagated failure.
 */
Error LlavaRunner::generate(
    std::vector<Image>& images,
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  // Fix: the previous message said "null", but a const std::string& can never
  // be null — the actual precondition being checked is non-emptiness.
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  // Wrap the token_callback with print function so every generated piece is
  // echoed to stdout regardless of whether the caller supplied a callback.
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        util::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  // Running KV-cache position, advanced after each prefill stage.
  int64_t pos = 0;

  // prefill preset prompt (with BOS, no EOS)
  std::vector<uint64_t> preset_prompt_tokens =
      ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
  size_t num_preset_tokens = preset_prompt_tokens.size();

  ET_UNWRAP(text_prefiller_->prefill(preset_prompt_tokens, pos));
  pos += num_preset_tokens;

  // prefill images; each image advances pos by its embedding sequence length
  // (dimension 1 of the returned logits).
  for (auto& image : images) {
    auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
    pos += logits.size(1);
  }

  // prefill user prompt. No BOS because preset prompt already has it.
  std::vector<uint64_t> user_prompt_tokens =
      ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
  size_t num_user_tokens = user_prompt_tokens.size();

  uint64_t prefill_next_token = ET_UNWRAP(
      text_prefiller_->prefill(user_prompt_tokens, pos, wrapped_callback));
  pos += num_user_tokens;

  // Generate tokens, starting from the token predicted by the prompt prefill.
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      {prefill_next_token}, pos, seq_len, wrapped_callback));

  // Bookkeeping
  stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
  stats_.num_generated_tokens = num_generated_tokens;
  ::executorch::llm::print_report(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}
149+
150+
} // namespace torch::executor
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// A simple multimodal LLM runner that includes preprocessing and post
10+
// processing logic.
11+
#pragma once
12+
13+
#include <cstdint>
14+
#include <functional>
15+
#include <memory>
16+
#include <string>
17+
#include <type_traits>
18+
#include <unordered_map>
19+
20+
#include <executorch/extension/llm/runner/multimodal_runner.h>
21+
#include <executorch/extension/llm/runner/stats.h>
22+
#include <executorch/extension/llm/runner/text_decoder_runner.h>
23+
#include <executorch/extension/llm/runner/text_prefiller.h>
24+
#include <executorch/extension/llm/runner/text_token_generator.h>
25+
#include <executorch/extension/llm/sampler/sampler.h>
26+
#include <executorch/extension/llm/tokenizer/tokenizer.h>
27+
#include <executorch/extension/module/module.h>
28+
#include <executorch/extension/runner_util/managed_tensor.h>
29+
30+
namespace torch::executor {
31+
32+
// Multimodal runner for LLaVA: loads a .pte exposing image_encoder,
// token_embedding and text_decoder methods, plus a BPE tokenizer, and
// generates text conditioned on images and a user prompt.
// NOTE(review): is_loaded/load/generate presumably override virtuals on
// MultimodalRunner — confirm against the base header and add `override`.
class LlavaRunner : public MultimodalRunner {
 public:
  // model_path: path to the exported LLaVA .pte file.
  // tokenizer_path: path to the serialized BPE tokenizer.
  // temperature: sampling temperature forwarded to the text decoder runner.
  explicit LlavaRunner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f)
      : MultimodalRunner(model_path, tokenizer_path, temperature){};
  // True when all three model methods are loaded and every runner component
  // (tokenizer, prefillers, token generator) has been constructed.
  bool is_loaded();
  // Loads model methods and constructs all components; idempotent.
  Error load();
  // Runs preset-prompt, image, and user-prompt prefill, then token
  // generation, streaming pieces to token_callback.
  Error generate(
      std::vector<Image>& images,
      const std::string& prompt,
      int32_t seq_len = 1024,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {});

 private:
  // System prompt prefilled (with BOS) before any image or user tokens.
  inline static const std::string kPresetPrompt =
      "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
};
52+
53+
} // namespace torch::executor
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Given inputs, run a text decoder in Llava and return the output.
10+
11+
#pragma once
12+
13+
#include <executorch/extension/llm/runner/text_decoder_runner.h>
14+
15+
namespace torch::executor {
16+
17+
class LlavaTextDecoderRunner : public TextDecoderRunner {
18+
public:
19+
LlavaTextDecoderRunner(Module* module, int32_t vocab_size, float temperature)
20+
: TextDecoderRunner(module, true, vocab_size, temperature){};
21+
22+
Result<exec_aten::Tensor> step(
23+
ManagedTensor& managed_tokens,
24+
ManagedTensor& managed_start_pos) {
25+
auto tokens = managed_tokens.get_aliasing_tensor();
26+
auto start_pos = managed_start_pos.get_aliasing_tensor();
27+
28+
// run token embedding
29+
std::vector<EValue> token_embedding_outputs =
30+
ET_UNWRAP(module_->execute("token_embedding", {tokens}));
31+
32+
// run text model
33+
std::vector<EValue> outputs_res = ET_UNWRAP(module_->execute(
34+
"text_decoder", {start_pos, token_embedding_outputs[0]}));
35+
36+
ET_CHECK_MSG(
37+
outputs_res.size() == 1,
38+
"More then one output returned from executing LLM.");
39+
ET_CHECK_MSG(
40+
outputs_res[0].isTensor(),
41+
"Non Tensor Output returned from executing LLM");
42+
43+
// Return the logits tensor
44+
return outputs_res[0].toTensor();
45+
}
46+
};
47+
48+
} // namespace torch::executor

0 commit comments

Comments
 (0)