@@ -93,6 +93,32 @@ Result<uint64_t> LlavaRunner::prefill_prompt(
93
93
return text_prefiller_->prefill (prompt_tokens, start_pos);
94
94
}
95
95
96
+ Error LlavaRunner::generate_from_pos (
97
+ const std::string& prompt,
98
+ int32_t seq_len,
99
+ int64_t start_pos,
100
+ std::function<void (const std::string&)> token_callback,
101
+ std::function<void(const ::executorch::extension::llm::Stats&)>
102
+ stats_callback) {
103
+ // prefill user prompt. No BOS because preset prompt already has it.
104
+ token_callback (prompt);
105
+
106
+ uint64_t prefill_next_token =
107
+ ET_UNWRAP (prefill_prompt (prompt, start_pos, /* bos=*/ 0 , /* eos*/ 0 ));
108
+ stats_.num_prompt_tokens = start_pos;
109
+
110
+ // Generate tokens
111
+ int64_t num_generated_tokens = ET_UNWRAP (text_token_generator_->generate (
112
+ {prefill_next_token}, start_pos, seq_len, token_callback));
113
+
114
+ // Bookkeeping
115
+ stats_.num_generated_tokens = num_generated_tokens;
116
+ ::executorch::llm::print_report (stats_);
117
+ if (stats_callback) {
118
+ stats_callback (stats_);
119
+ }
120
+ }
121
+
96
122
Error LlavaRunner::generate (
97
123
std::vector<Image> images,
98
124
const std::string& prompt,
@@ -122,25 +148,9 @@ Error LlavaRunner::generate(
122
148
// prefill images
123
149
prefill_images (images, pos);
124
150
125
- // prefill user prompt. No BOS because preset prompt already has it.
126
- wrapped_callback (prompt);
127
-
128
- uint64_t prefill_next_token =
129
- ET_UNWRAP (prefill_prompt (prompt, pos, /* bos=*/ 0 , /* eos*/ 0 ));
130
- stats_.num_prompt_tokens = pos;
131
-
132
151
// Generate tokens
133
- int64_t num_generated_tokens = ET_UNWRAP (text_token_generator_->generate (
134
- {prefill_next_token}, pos, seq_len, wrapped_callback));
135
-
136
- // Bookkeeping
137
- stats_.num_generated_tokens = num_generated_tokens;
138
- ::executorch::llm::print_report (stats_);
139
- if (stats_callback) {
140
- stats_callback (stats_);
141
- }
142
-
143
- return Error::Ok;
152
+ return generate_from_pos (
153
+ prompt, seq_len, pos, wrapped_callback, stats_callback);
144
154
}
145
155
146
156
} // namespace torch::executor
0 commit comments