
Commit b43579c

[llava][18/N] Move token generation loop to a class
As titled. This PR moves the token generation loop in the llama2 runner into a new class so it can be reused.

ghstack-source-id: 1108ada
Pull Request resolved: #4652
1 parent e71fa03 commit b43579c

File tree

8 files changed: +203 -73 lines

examples/models/llama2/runner/runner.cpp
Lines changed: 19 additions & 69 deletions

@@ -52,7 +52,8 @@ Runner::Runner(
 }
 
 bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_;
+  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
+      text_prefiller_ && text_token_generator_;
 }
 
 Error Runner::load() {

@@ -104,6 +105,13 @@ Error Runner::load() {
       use_kv_cache_,
       enable_parallel_prefill_);
 
+  text_token_generator_ = std::make_unique<TextTokenGenerator>(
+      tokenizer_.get(),
+      text_decoder_runner_.get(),
+      use_kv_cache_,
+      eos_id_,
+      &stats_);
+
   return Error::Ok;
 }
 
@@ -176,81 +184,19 @@ Error Runner::generate(
   wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
 
   // start the main loop
-  int64_t pos = num_prompt_tokens; // position in the sequence
-
-  // Generate the rest of the sequence
-  std::vector<uint64_t> token_data; // allocate space for the tokens
-  std::vector<exec_aten::SizesType> token_shape;
-
-  if (use_kv_cache_) {
-    // hard code these to size 1 as kv cache is locked to static size right now.
-    token_data = {cur_token};
-    token_shape = {1, 1};
-  } else {
-    token_data = prompt_tokens;
-    token_data.push_back(cur_token);
-    token_shape = {1, num_prompt_tokens + 1};
-  }
-
-  // initialize tensor wrappers
-  ManagedTensor tokens_managed(
-      token_data.data(), token_shape, ScalarType::Long);
-
-  ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);
-
-  uint64_t prev_token;
-
-  // Generate our tokens
-  while (pos < seq_len - 1) {
-    // Run the model
-    Result<exec_aten::Tensor> logits_res =
-        text_decoder_runner_->step(tokens_managed, start_pos_managed);
+  prompt_tokens.push_back(cur_token);
+  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
+      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
 
-    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
-    exec_aten::Tensor& logits_tensor = logits_res.get();
-
-    prev_token = cur_token;
-
-    long sample_start_time_ms = util::time_in_ms();
-    cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
-    stats_.aggregate_sampling_time_ms +=
-        util::time_in_ms() - sample_start_time_ms;
-
-    pos++;
-
-    if (use_kv_cache_) {
-      // update the token tensor. token_data will not be empty.
-      // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
-      token_data[0] = cur_token;
-    } else {
-      // push it to the back
-      token_data.push_back(cur_token);
-      tokens_managed.resize({1, static_cast<int>(token_data.size())});
-    }
-
-    // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (pos >= num_prompt_tokens && cur_token == eos_id_) {
-      printf("\n");
-      ET_LOG(Info, "\nReached to the end of generation");
-      break;
-    }
-
-    // print the token as string, decode it with the Tokenizer object
-    wrapped_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
-
-    if (shouldStop_) {
-      break;
-    }
-  }
   stats_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
-  if (pos == seq_len) {
+  if (num_prompt_tokens + num_generated_tokens == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = pos - num_prompt_tokens;
+  stats_.num_generated_tokens = num_generated_tokens;
   ::executorch::llm::print_report(stats_);
   if (stats_callback) {
     stats_callback(stats_);

@@ -260,6 +206,10 @@ Error Runner::generate(
 }
 
 void Runner::stop() {
-  shouldStop_ = true;
+  if (is_loaded()) {
+    text_token_generator_->stop();
+  } else {
+    ET_LOG(Error, "Token generator is not loaded, cannot stop");
+  }
 }
 } // namespace torch::executor
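
The net effect: Runner::generate() now delegates the whole decode loop to the new class, and Runner::stop() forwards to it instead of flipping a local flag. A condensed sketch of the post-refactor flow, drawn directly from the hunks above (surrounding setup and error handling elided):

    // Prefill produced cur_token; append it so the generator sees the full
    // context, then let the generator run until EOS, seq_len, or stop().
    prompt_tokens.push_back(cur_token);
    int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
        prompt_tokens,      // prompt plus the first generated token
        num_prompt_tokens,  // start position of the new tokens
        seq_len,            // total budget, prompt included
        wrapped_callback)); // streams each decoded piece to the caller
    stats_.num_generated_tokens = num_generated_tokens;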

examples/models/llama2/runner/runner.h
Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
+#include <executorch/extension/llm/runner/text_token_generator.h>
 #include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/llm/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>

@@ -66,6 +67,7 @@ class Runner {
   std::unique_ptr<Module> module_;
   std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
   std::unique_ptr<TextPrefiller> text_prefiller_;
+  std::unique_ptr<TextTokenGenerator> text_token_generator_;
   std::string tokenizer_path_;
   std::unique_ptr<Tokenizer> tokenizer_;

examples/models/llama2/runner/targets.bzl
Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ def define_common_targets():
             "//executorch/extension/llm/runner:stats",
             "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
             "//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
+            "//executorch/extension/llm/runner:text_token_generator" + aten_suffix,
             "//executorch/extension/evalue_util:print_evalue" + aten_suffix,
             "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
             "//executorch/extension/module:module" + aten_suffix,

extension/llm/runner/stats.h
Lines changed: 14 additions & 2 deletions

@@ -8,12 +8,12 @@
 
 // Runner stats for LLM
 #pragma once
+#include <executorch/extension/llm/runner/util.h>
+#include <executorch/runtime/platform/log.h>
 #include <cinttypes>
 #include <sstream>
 // patternlint-disable-next-line executorch-cpp-nostdinc
 #include <string>
-
-#include <executorch/runtime/platform/log.h>
 namespace executorch::llm {
 
 struct Stats {

@@ -40,6 +40,18 @@ struct Stats {
   int64_t num_prompt_tokens;
   // Token count from generated (total - prompt)
   int64_t num_generated_tokens;
+  inline void on_sampling_begin() {
+    aggregate_sampling_timer_start_timestamp =
+        ::torch::executor::util::time_in_ms();
+  }
+  inline void on_sampling_end() {
+    aggregate_sampling_time_ms += ::torch::executor::util::time_in_ms() -
+        aggregate_sampling_timer_start_timestamp;
+    aggregate_sampling_timer_start_timestamp = 0;
+  }
+
+ private:
+  long aggregate_sampling_timer_start_timestamp = 0;
 };
 
 static constexpr auto kTopp = 0.9f;
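
The two new hooks bracket a sampling call so callers no longer juggle timestamps themselves; the start time lives in the private aggregate_sampling_timer_start_timestamp field between the calls. A minimal usage sketch, matching how the new generator class below uses it:

    // Time one sampling step; the elapsed milliseconds accumulate into
    // stats_->aggregate_sampling_time_ms.
    stats_->on_sampling_begin();
    cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
    stats_->on_sampling_end();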

extension/llm/runner/targets.bzl
Lines changed: 14 additions & 0 deletions

@@ -45,6 +45,20 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "text_token_generator" + aten_suffix,
+        exported_headers = ["text_token_generator.h"],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            ":text_decoder_runner" + aten_suffix,
+            "//executorch/extension/llm/tokenizer:tokenizer_header",
+            "//executorch/extension/module:module" + aten_suffix,
+            "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
+        ],
+    )
+
     runtime.cxx_library(
         name = "metadata_util" + aten_suffix,
         exported_headers = ["metadata_util.h"],

extension/llm/runner/text_decoder_runner.h
Lines changed: 5 additions & 0 deletions

@@ -53,6 +53,10 @@ class TextDecoderRunner {
     return module_->is_method_loaded(method_name);
   }
 
+  inline void stop() {
+    should_stop_ = true;
+  }
+
   /**
    * Sample the next token from the logits tensor.
    * @param logits_tensor The logits tensor.

@@ -90,6 +94,7 @@ class TextDecoderRunner {
   Module* module_;
   std::unique_ptr<Sampler> sampler_;
   bool use_kv_cache_;
+  bool should_stop_{false};
 };
 
 } // namespace torch::executor
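
This gives the decoder runner its own stop flag, separate from the generator's; none of the hunks in this commit show where should_stop_ is read, so the consumer is presumably added elsewhere. A hypothetical sketch of the intended usage, assuming cancellation is requested from another thread (the timeout policy and threading here are illustrative assumptions, not code from this PR):

    #include <chrono>
    #include <thread>

    // Assumed usage: another thread requests cancellation; generation is
    // expected to end once the flag is observed by the decode path.
    std::thread canceller([&] {
      std::this_thread::sleep_for(std::chrono::seconds(5)); // assumed policy
      text_decoder_runner->stop(); // sets should_stop_
    });
    canceller.join();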

extension/llm/runner/text_prefiller.cpp
Lines changed: 9 additions & 2 deletions

@@ -59,9 +59,11 @@ Result<uint64_t> TextPrefiller::prefill(
   // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
   uint64_t prev = prompt_tokens[0];
   uint64_t cur;
-  for (int i = 1; i < prompt_tokens.size(); i++) {
+  for (int i = 0; i < prompt_tokens.size(); i++) {
     cur = prompt_tokens[i];
-    token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
+    if (cur != tokenizer_->bos_tok()) {
+      token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
+    }
     prev = cur;
   }
   cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());

@@ -82,6 +84,11 @@ Result<uint64_t> TextPrefiller::prefill(
   // is bos so don't callback.
   exec_aten::Tensor logits_tensor = ET_UNWRAP(
       text_decoder_runner_->step(managed_tokens, managed_start_pos));
+
+  // if first token is not bos, we need to callback
+  if (cur_token != tokenizer_->bos_tok()) {
+    token_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
+  }
   pos = 1; // start from index 1
 
   while (pos < num_prompt_tokens) {
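
The reworked loop starts at index 0 and filters BOS explicitly, so a leading BOS token no longer produces a spurious callback, while prompts without BOS still echo every token. A worked trace with hypothetical token values, assuming bos_tok() == 1 and prompt_tokens = {1, 15, 42}:

    // i = 0: cur = 1 (bos) -> filtered, no callback; prev = 1
    // i = 1: cur = 15      -> token_callback(decode(1, 15));  prev = 15
    // i = 2: cur = 42      -> token_callback(decode(15, 42)); prev = 42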
extension/llm/runner/text_token_generator.h
Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Generate tokens in a loop.
+#pragma once
+
+#include <executorch/extension/llm/runner/stats.h>
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/tokenizer/tokenizer.h>
+
+namespace torch::executor {
+using Stats = ::executorch::llm::Stats;
+
+class TextTokenGenerator {
+ public:
+  TextTokenGenerator(
+      Tokenizer* tokenizer,
+      TextDecoderRunner* text_decoder_runner,
+      bool use_kv_cache,
+      uint64_t eos_id,
+      Stats* stats)
+      : tokenizer_(tokenizer),
+        text_decoder_runner_(text_decoder_runner),
+        eos_id_(eos_id),
+        use_kv_cache_(use_kv_cache),
+        stats_(stats) {}
+
+  /**
+   * Token generation loop.
+   * @param tokens prompt tokens as well as the first token generated by
+   * prefill.
+   * @param start_pos the start position of the new tokens, based on how many
+   * prompt tokens are prefilled.
+   * @param seq_len the total sequence length, including the prompt tokens,
+   * the next token from prefill, and new tokens.
+   * @param token_callback what to do after a token is generated.
+   * @return how many tokens are generated.
+   */
+  inline Result<int64_t> generate(
+      std::vector<uint64_t> tokens,
+      int64_t start_pos,
+      int32_t seq_len,
+      std::function<void(const std::string&)> token_callback) {
+    ET_CHECK_MSG(
+        !tokens.empty(), "Token generation loop shouldn't take empty tokens");
+    int64_t pos = start_pos; // position in the sequence
+
+    std::vector<uint64_t> token_data; // allocate space for the tokens
+    std::vector<exec_aten::SizesType> token_shape;
+
+    // Token after prefill
+    uint64_t cur_token = tokens.back();
+    uint64_t prev_token;
+
+    if (use_kv_cache_) {
+      // hard code these to size 1 as kv cache is locked to static size right
+      // now.
+      token_data = {cur_token};
+      token_shape = {1, 1};
+    } else {
+      token_data = tokens;
+      token_shape = {1, static_cast<int>(tokens.size())};
+    }
+
+    // initialize tensor wrappers
+    ManagedTensor tokens_managed(
+        token_data.data(), token_shape, ScalarType::Long);
+
+    ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);
+
+    // Generate our tokens
+    while (pos < seq_len) {
+      // Run the model
+      Result<exec_aten::Tensor> logits_res =
+          text_decoder_runner_->step(tokens_managed, start_pos_managed);
+
+      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+      exec_aten::Tensor& logits_tensor = logits_res.get();
+
+      prev_token = cur_token;
+
+      stats_->on_sampling_begin();
+      cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
+      stats_->on_sampling_end();
+
+      pos++;
+
+      if (use_kv_cache_) {
+        // update the token tensor. token_data will not be empty.
+        // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
+        token_data[0] = cur_token;
+      } else {
+        // push it to the back
+        token_data.push_back(cur_token);
+        tokens_managed.resize({1, static_cast<int>(token_data.size())});
+      }
+
+      // print the token as string, decode it with the Tokenizer object
+      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
+
+      if (should_stop_) {
+        break;
+      }
+
+      // data-dependent terminating condition: we have n_eos_ number of EOS
+      if (cur_token == eos_id_) {
+        printf("\n");
+        ET_LOG(Info, "\nReached to the end of generation");
+        break;
+      }
+    }
+    return pos - start_pos;
+  }
+
+  /**
+   * Stop the generation loop.
+   */
+  inline void stop() {
+    should_stop_ = true;
+  }
+
+ private:
+  Tokenizer* tokenizer_;
+  TextDecoderRunner* text_decoder_runner_;
+  uint64_t eos_id_;
+  bool use_kv_cache_;
+
+  // state machine
+  bool should_stop_ = false;
+
+  // stats
+  Stats* stats_;
+};
+} // namespace torch::executor
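
Since reuse is the stated goal of this PR, here is a minimal wiring sketch for a client other than the llama2 runner, mirroring Runner::load() and Runner::generate() above. It assumes tokenizer, decoder_runner, stats, prompt_tokens, and the prefill step already exist, and that the EOS id comes from the tokenizer (an assumption of this sketch, not something the diff specifies):

    #include <executorch/extension/llm/runner/text_token_generator.h>

    // Construct once, after the decoder runner and tokenizer are loaded.
    auto generator = std::make_unique<torch::executor::TextTokenGenerator>(
        tokenizer.get(),
        decoder_runner.get(),
        /*use_kv_cache=*/true,
        /*eos_id=*/tokenizer->eos_tok(), // assumption: EOS id from tokenizer
        &stats);

    // After prefill: append the first generated token, then run the loop.
    prompt_tokens.push_back(first_token); // first_token produced by prefill
    int64_t generated = ET_UNWRAP(generator->generate(
        prompt_tokens,
        /*start_pos=*/static_cast<int64_t>(num_prompt_tokens),
        /*seq_len=*/128,
        [](const std::string& piece) { ::printf("%s", piece.c_str()); }));

    // generator->stop() may be called from another thread; the loop exits
    // after the next token is emitted.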
