/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple LLaVA runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.
12
#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

#include <ctime>
#include <memory>
#include <sstream>
#include <vector>
21
+
22
+ namespace torch ::executor {
23
+
24
+ bool LlavaRunner::is_loaded () {
25
+ Result<std::unordered_set<std::string>> methods_res = module_->method_names ();
26
+ if (methods_res.error () != Error::Ok) {
27
+ ET_LOG (Error, " Failed to get method names" );
28
+ ET_CHECK_MSG (false , " Failed to get method names" );
29
+ }
30
+ std::unordered_set<std::string> methods = methods_res.get ();
31
+ bool methods_exist = methods.find (" image_encoder" ) != methods.end () &&
32
+ methods.find (" token_embedding" ) != methods.end () &&
33
+ methods.find (" text_decoder" ) != methods.end ();
34
+ if (!methods_exist) {
35
+ for (const auto & method : methods) {
36
+ ET_LOG (Error, " Method: %s" , method.c_str ());
37
+ }
38
+ ET_CHECK_MSG (
39
+ methods_exist,
40
+ " Missing required methods (image_encoder, token_embedding, text_decoder) in the model" );
41
+ }
42
+ bool methods_loaded = module_->is_method_loaded (" image_encoder" ) &&
43
+ module_->is_method_loaded (" token_embedding" ) &&
44
+ module_->is_method_loaded (" text_decoder" );
45
+ return methods_loaded && tokenizer_ && text_decoder_runner_ &&
46
+ text_prefiller_ && image_prefiller_ && text_token_generator_;
47
+ }
48
+
49
+ Error LlavaRunner::load () {
50
+ if (is_loaded ()) {
51
+ return Error::Ok;
52
+ }
53
+ stats_.model_load_start_ms = util::time_in_ms ();
54
+
55
+ ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (" image_encoder" ));
56
+ ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (" token_embedding" ));
57
+ ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (" text_decoder" ));
58
+
59
+ // Load the tokenizer
60
+ tokenizer_ = std::make_unique<BPETokenizer>();
61
+ tokenizer_->load (tokenizer_path_);
62
+
63
+ // Load the text decoder runner
64
+ text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
65
+ module_.get (), tokenizer_->vocab_size (), temperature_);
66
+
67
+ // Load the text prefiller
68
+ text_prefiller_ = std::make_unique<TextPrefiller>(
69
+ tokenizer_.get (),
70
+ text_decoder_runner_.get (),
71
+ /* use_kv_cache=*/ true ,
72
+ /* enable_parallel_prefill=*/ true );
73
+
74
+ // Load the image prefiller
75
+ image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get ());
76
+
77
+ // Load the text token generator
78
+ text_token_generator_ = std::make_unique<TextTokenGenerator>(
79
+ tokenizer_.get (),
80
+ text_decoder_runner_.get (),
81
+ /* use_kv_cache=*/ true ,
82
+ tokenizer_->eos_tok (),
83
+ &stats_);
84
+
85
+ stats_.model_load_end_ms = util::time_in_ms ();
86
+ return Error::Ok;
87
+ }
88
+
89
+ Error LlavaRunner::generate (
90
+ std::vector<Image>& images,
91
+ const std::string& prompt,
92
+ int32_t seq_len,
93
+ std::function<void (const std::string&)> token_callback,
94
+ std::function<void(const Stats&)> stats_callback) {
95
+ ET_CHECK_MSG (!prompt.empty (), " Prompt cannot be null" );
96
+ if (!is_loaded ()) {
97
+ ET_CHECK_OK_OR_RETURN_ERROR (load ());
98
+ }
99
+
100
+ // Wrap the token_callback with print function
101
+ std::function<void (const std::string&)> wrapped_callback =
102
+ [token_callback](const std::string& piece) {
103
+ util::safe_printf (piece.c_str ());
104
+ fflush (stdout);
105
+ if (token_callback) {
106
+ token_callback (piece);
107
+ }
108
+ };
109
+
110
+ int64_t pos = 0 ;
111
+
112
+ // prefill preset prompt
113
+ std::vector<uint64_t > preset_prompt_tokens =
114
+ ET_UNWRAP (tokenizer_->encode (kPresetPrompt , /* bos=*/ 1 , /* eos=*/ 0 ));
115
+ size_t num_preset_tokens = preset_prompt_tokens.size ();
116
+
117
+ ET_UNWRAP (text_prefiller_->prefill (preset_prompt_tokens, pos));
118
+ pos += num_preset_tokens;
119
+
120
+ // prefill images
121
+ for (auto & image : images) {
122
+ auto logits = ET_UNWRAP (image_prefiller_->prefill (image, pos));
123
+ pos += logits.size (1 );
124
+ }
125
+
126
+ // prefill user prompt. No BOS because preset prompt already has it.
127
+ std::vector<uint64_t > user_prompt_tokens =
128
+ ET_UNWRAP (tokenizer_->encode (prompt, /* bos=*/ 0 , /* eos=*/ 0 ));
129
+ size_t num_user_tokens = user_prompt_tokens.size ();
130
+
131
+ uint64_t prefill_next_token = ET_UNWRAP (
132
+ text_prefiller_->prefill (user_prompt_tokens, pos, wrapped_callback));
133
+ pos += num_user_tokens;
134
+
135
+ // Generate tokens
136
+ int64_t num_generated_tokens = ET_UNWRAP (text_token_generator_->generate (
137
+ {prefill_next_token}, pos, seq_len, wrapped_callback));
138
+
139
+ // Bookkeeping
140
+ stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
141
+ stats_.num_generated_tokens = num_generated_tokens;
142
+ ::executorch::llm::print_report (stats_);
143
+ if (stats_callback) {
144
+ stats_callback (stats_);
145
+ }
146
+
147
+ return Error::Ok;
148
+ }
149
+
150
+ } // namespace torch::executor