Commit 88dcd2c

JacobSzwejbka authored and facebook-github-bot committed
refactor runner to fix bugs (#2752)
Summary: The kv-cache implementation of the runner had a number of bugs, mostly different ways the prompt sequence was getting trashed in kv-cache mode. This change refactors the runner to fix them.

Reviewed By: iseeyuan

Differential Revision: D55496044
1 parent a624345 commit 88dcd2c
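
The core of the refactor is to advance generation one position at a time and to force the next prompt token during prefill so the prompt sequence can never be clobbered by a sampled token. The following standalone C++ sketch illustrates that bookkeeping only; model_step is a stub and the token ids are made up, none of it is the actual runner or ExecuTorch API (see the runner.cpp diff below for the real code).

// Standalone illustration of the prefill/decode bookkeeping this commit
// adopts. model_step is a stub, not an ExecuTorch call; token ids are fake.
#include <cstdint>
#include <cstdio>
#include <vector>

// Stub: pretend the model returns a sampled token for the given input.
static int64_t model_step(int64_t input_token, int64_t start_pos) {
  return (input_token + start_pos) % 32000; // placeholder "sampling"
}

int main() {
  const std::vector<int64_t> prompt_tokens = {1, 15043, 3186}; // hypothetical ids
  const int64_t num_prompt_tokens = static_cast<int64_t>(prompt_tokens.size());
  const int64_t seq_len = 16;

  int64_t pos = 0;
  int64_t cur_token = prompt_tokens[0];

  while (pos < seq_len - 1) {
    int64_t sampled = model_step(cur_token, pos);
    // During prefill, force the next prompt token instead of the sample so the
    // prompt sequence is never overwritten (the class of bug this commit fixes).
    cur_token = (pos < num_prompt_tokens - 1) ? prompt_tokens[pos + 1] : sampled;
    pos++;
    printf("pos=%lld token=%lld\n", static_cast<long long>(pos),
           static_cast<long long>(cur_token));
  }
  return 0;
}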

File tree

examples/models/llama2/runner/runner.cpp
examples/models/llama2/runner/runner.h
examples/models/llama2/runner/targets.bzl
extension/runner_util/managed_tensor.h
extension/runner_util/targets.bzl

5 files changed: +160 −96


examples/models/llama2/runner/runner.cpp

Lines changed: 137 additions & 94 deletions
@@ -10,6 +10,7 @@
 // The module takes in a string as input and emits a string as output.
 
 #include <executorch/examples/models/llama2/runner/runner.h>
+#include <executorch/extension/evalue_util/print_evalue.h>
 #include <executorch/extension/runner_util/managed_tensor.h>
 
 #include <ctime>
@@ -121,24 +122,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
   return res;
 }
 
-std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
-  // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
-  // self.head_dim)
-  std::vector<std::string> methods = {
-      "get_n_layers",
-      "get_max_batch_size",
-      "get_max_seq_len",
-      "get_n_kv_heads",
-      "get_head_dim"};
-  std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
-  std::vector<exec_aten::SizesType> result;
-  for (int i = 0; i < methods.size(); ++i) {
-    // convert from int64_t to int32_t
-    result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
-  }
-  return result;
-}
-
 template <typename T>
 int32_t Runner::logitsToToken(
     const exec_aten::Tensor& logits_tensor,
@@ -155,6 +138,73 @@ int32_t Runner::logitsToToken(
   return sampler_->sample(logits_last);
 }
 
+// Given an input token. Set up the inputs for the model and execute a single
+// step. Returning the logits tensor.
+Result<torch::executor::Tensor> Runner::run_model_step(
+    int64_t input_token,
+    ManagedTensor& managed_tokens,
+    ManagedTensor& managed_start_pos,
+    size_t max_seq_len) {
+  // ET_LOG(Info, "Input token %" PRIu64, input_token);
+  if (use_kv_cache_) {
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When using kv-cache our input is always 1 token, so just update to the
+    // latest.
+    tokens.mutable_data_ptr<int64_t>()[0] = input_token;
+
+    // inputs:[tokens, start_pos]
+    inputs.push_back(tokens);
+    inputs.push_back(start_pos);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More then one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    // Bump start_pos by 1
+    start_pos.mutable_data_ptr<int64_t>()[0]++;
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  } else { // no kv cache
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    (void)managed_start_pos; // unused
+
+    // When not using kv-cache our input is the entire history of tokens we have
+    // seen, so resize input to be 1 larger and append the new token to the end.
+    // TODO does this work in ATen mode?
+    tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;
+
+    // inputs:[tokens]
+    inputs.push_back(tokens);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More then one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    if (tokens.size(1) < max_seq_len) {
+      // Resize the tokens tensor to be 1 larger for next step.
+      managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
+    }
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  }
+}
+
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
@@ -189,9 +239,6 @@ Error Runner::generate(
       prompt_tokens,
       &num_prompt_tokens);
 
-  for (int i = 0; i < num_prompt_tokens; i++) {
-    ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
-  }
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
       num_prompt_tokens < max_seq_len_,
@@ -202,89 +249,94 @@
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   // start the main loop
-  int next; // will store the next token in the sequence
-  int64_t pos = num_prompt_tokens - 1; // position in the sequence
-  int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
-  int logits_index = 0; // index of the logits tensor in the output
-  std::vector<exec_aten::SizesType> input_shape = {1, 1};
-  std::vector<exec_aten::SizesType> pos_shape = {1};
+  int64_t pos = 0; // position in the sequence
+
   std::vector<int64_t> token_data; // allocate space for the tokens
-  std::vector<int64_t> pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
+
+  std::vector<int64_t> start_pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> start_pos_shape = {1};
 
   if (use_kv_cache_) {
-    // set pos to 0, refill token by token
-    pos = 0;
+    // hard code these to size 1 as kv cache is locked to static size right now.
     token_data.resize(1);
-    pos_data.resize(seq_len);
+    token_shape[1] = 1;
+    start_pos_data.resize(1);
+    start_pos_data.push_back(0);
   } else {
-    // reserve data for tokens, notice the size is still 0.
+    // reserve data for tokens, notice the size is still 0 but the capacity is
+    // seq_len.
     token_data.resize(seq_len);
   }
 
   // initialize tensor wrappers
-  ManagedTensor pos_managed(
-      pos_data.data(), pos_data.size(), pos_shape, ScalarType::Long);
-
-  // copy prompt tokens into data
-  for (int i = 0; i <= pos; ++i) {
-    // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-    token_data[i] = prompt_tokens[i];
-    if (i > 0) {
-      printf(
-          "%s",
-          ET_UNWRAP(
-              tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
+  ManagedTensor tokens_managed(
+      token_data.data(),
+      128, // TODO clean up unused 128 here as ManagedTensor ignores this arg in
+           // ctor
+      token_shape,
+      ScalarType::Long);
+  // Create with the max shape to approapriately set the capacity of this
+  // tensor, then resize back to 1 for first input.
+  tokens_managed.resize({1, 1});
+
+  ManagedTensor start_pos_managed(
+      start_pos_data.data(), 128, start_pos_shape, ScalarType::Long);
+
+  int64_t prev_token;
+  int64_t cur_token = prompt_tokens[0];
+
+  // If we arent using the kv cache then we can batch prefill the prompt
+  if (!use_kv_cache_) {
+    tokens_managed.resize({1, num_prompt_tokens});
+    for (int i = 0; i < num_prompt_tokens - 1; i++) {
+      tokens_managed.get_aliasing_tensor().mutable_data_ptr<int64_t>()[i] =
+          prompt_tokens[i];
+    }
+    // prefill tokens up to the last prompt token and then enter the loop with
+    // the last promp token as the current token.
+    cur_token = prompt_tokens[num_prompt_tokens - 1];
+    pos = num_prompt_tokens - 1;
+
+    // Print the prompt for consistent output between single token prefill and
+    // batch prefill.
+    int prev = prompt_tokens[0];
+    int cur;
+    for (int i = 1; i < num_prompt_tokens; i++) {
+      cur = prompt_tokens[i];
+      auto piece_res = tokenizer_->decode(prev, cur);
+      ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
+      util::safe_printf(piece_res.get());
+      fflush(stdout);
+      prev = cur;
     }
   }
 
-  // create a 1xN int tensor with next as value
-  while (pos + 1 < seq_len) {
-    // ET_LOG(Info, "Generating step %d...", pos);
-    // set the current token in the tensor
-    std::vector<EValue> inputs;
-    if (use_kv_cache_) {
-      token_data[0] = token;
-      input_shape[1] = 1;
-      // inputs: [tokens, start_pos, k_cache, v_cache]
-      inputs.emplace_back(pos_managed.get_aliasing_tensor());
-    } else {
-      // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-      token_data[pos] = token;
-      input_shape[1] = pos + 1;
-    }
-    ManagedTensor token_managed(
-        token_data.data(), token_data.size(), input_shape, ScalarType::Long);
-    inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
-    // For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
-    // Otherwise inputs: [tokens]
-    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
-    ET_CHECK_MSG(
-        outputs_res.ok(),
-        "Execution of method forward failed with status 0x%" PRIx32,
-        static_cast<int32_t>(outputs_res.error()));
-    // ET_LOG(Info, "Model executed successfully.");
+  // Generate our tokens
+  while (pos < seq_len - 1) {
+    // Run the model
+    Result<torch::executor::Tensor> logits_res =
+        run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
 
-    std::vector<EValue> outputs = outputs_res.get();
-    // Check the outputs.
-    ET_CHECK_MSG(
-        outputs.size() == 1 && outputs.at(0).isTensor(),
-        "Expecting output to have exactly 1 tensor output. Got %zu outputs.",
-        outputs.size());
     if (pos == num_prompt_tokens) {
       timers_.first_token_ms = util::time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
       timers_.prompt_eval_end_ms = util::time_in_ms();
     }
-    int32_t next_tok;
-    exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
+
+    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+    exec_aten::Tensor& logits_tensor = logits_res.get();
+
+    prev_token = cur_token;
+
     long sample_start_time_ms = util::time_in_ms();
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
-        next_tok = logitsToToken<float>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<float>(logits_tensor, pos, 0);
         break;
       }
       case ScalarType::Half: {
-        next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
         break;
       }
       default:
@@ -299,19 +351,12 @@ Error Runner::generate(
     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
       // prefill, force the next token to be the next prompt token
-      next = prompt_tokens[pos + 1];
-    } else {
-      // otherwise sample the next token from the logits
-      next = next_tok;
+      cur_token = prompt_tokens[pos + 1];
     }
-    // ET_LOG(Info, "Output saved, next = %d", next);
     pos++;
-    if (use_kv_cache_) {
-      pos_data.at(0) = pos;
-    }
 
     // print the token as string, decode it with the Tokenizer object
-    auto piece_res = tokenizer_->decode(token, next);
+    auto piece_res = tokenizer_->decode(prev_token, cur_token);
     ET_CHECK(piece_res.ok());
     const char* piece = piece_res.get();
 
@@ -328,22 +373,20 @@ Error Runner::generate(
     }
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (pos >= num_prompt_tokens && next == eos_id_) {
+    if (pos >= num_prompt_tokens && cur_token == eos_id_) {
       printf("\n");
       ET_LOG(Info, "\nReached to the end of generation");
       break;
     }
-
-    token = next;
   }
   timers_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
-  if (pos + 1 == seq_len) {
+  if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
-  timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
+  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
 
   delete[] prompt_tokens;
   return Error::Ok;
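
To summarize the new data flow without the diff noise: in kv-cache mode run_model_step feeds the model exactly one token plus the start position, while in the no-cache path it feeds the full token history grown by one each step. The sketch below is a standalone illustration of those two input layouts only; Inputs and build_inputs are made-up stand-ins, not ExecuTorch or runner types.

// Standalone sketch of the two input layouts run_model_step assembles.
// Inputs/build_inputs are illustrative stubs, not ExecuTorch APIs.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Inputs {
  std::vector<int64_t> tokens;  // stands in for the 1xN tokens tensor
  int64_t start_pos = -1;       // only meaningful in kv-cache mode
};

Inputs build_inputs(bool use_kv_cache, const std::vector<int64_t>& history,
                    int64_t cur_token, int64_t pos) {
  Inputs in;
  if (use_kv_cache) {
    in.tokens = {cur_token};        // always exactly one token
    in.start_pos = pos;             // the model updates its cache at this offset
  } else {
    in.tokens = history;            // full history so far...
    in.tokens.push_back(cur_token); // ...plus the newest token at the end
  }
  return in;
}

int main() {
  auto kv = build_inputs(/*use_kv_cache=*/true, {1, 15043}, 3186, 2);
  auto full = build_inputs(/*use_kv_cache=*/false, {1, 15043}, 3186, 2);
  printf("kv-cache: %zu token(s), start_pos=%lld\n", kv.tokens.size(),
         static_cast<long long>(kv.start_pos));
  printf("no-cache: %zu token(s)\n", full.tokens.size());
  return 0;
}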

examples/models/llama2/runner/runner.h

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@
 #include <executorch/examples/models/llama2/sampler/sampler.h>
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
+#include <executorch/extension/runner_util/managed_tensor.h>
 
 namespace torch::executor {
 
@@ -45,7 +46,11 @@ class Runner {
   template <typename T>
   int32_t
   logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
-  std::vector<exec_aten::SizesType> getKVCacheShape();
+  Result<torch::executor::Tensor> run_model_step(
+      int64_t input_token,
+      ManagedTensor& tokens,
+      ManagedTensor& start_pos,
+      size_t max_seq_len);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;

examples/models/llama2/runner/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ def define_common_targets():
             "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/kernels/quantized:generated_lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
         ] + (_get_operator_lib(aten)) + ([
             # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
             # Therefore enable it explicitly for now to avoid failing tests

extension/runner_util/managed_tensor.h

Lines changed: 14 additions & 1 deletion
@@ -8,6 +8,9 @@
 
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/platform/assert.h>
+
 #ifdef USE_ATEN_LIB
 #include <torch/torch.h>
 #else
@@ -56,10 +59,20 @@ class ManagedTensor {
         data_ptr_,
         dim_order_.data(),
         strides_.data(),
-        TensorShapeDynamism::STATIC);
+        TensorShapeDynamism::DYNAMIC_BOUND);
 #endif
   }
 
+  void resize(const std::vector<SizesType>& new_sizes) {
+    ET_CHECK_MSG(
+        new_sizes.size() == sizes_.size(),
+        "Cannot change rank of a managed tensor");
+    auto err = resize_tensor(
+        this->get_aliasing_tensor(),
+        exec_aten::ArrayRef<SizesType>(new_sizes.data(), new_sizes.size()));
+    ET_CHECK(err == Error::Ok);
+  }
+
   /**
    * Get the underlying Tensor object. This is assuming the copying is cheap.
   */
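
The switch from TensorShapeDynamism::STATIC to DYNAMIC_BOUND, together with the new resize(), is what lets the runner create the token tensor at its maximum shape and then shrink or grow the logical size per step without reallocating. The sketch below illustrates that bounded-dynamic idea with a plain struct standing in for ManagedTensor; it is an assumption-labeled illustration, not the ExecuTorch implementation.

// Standalone sketch (no ExecuTorch dependency) of the "bounded dynamic" idea:
// the backing buffer is sized for the maximum shape once, and resize() only
// changes the logical sizes, never the rank or the allocation.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct BoundedTensor {
  std::vector<int64_t> data;   // capacity fixed at construction
  std::vector<int32_t> sizes;  // current logical sizes

  explicit BoundedTensor(std::vector<int32_t> max_sizes)
      : data(std::accumulate(max_sizes.begin(), max_sizes.end(), int64_t{1},
                             std::multiplies<int64_t>())),
        sizes(std::move(max_sizes)) {}

  void resize(const std::vector<int32_t>& new_sizes) {
    assert(new_sizes.size() == sizes.size() && "rank must not change");
    int64_t numel = std::accumulate(new_sizes.begin(), new_sizes.end(),
                                    int64_t{1}, std::multiplies<int64_t>());
    assert(numel <= static_cast<int64_t>(data.size()) && "exceeds capacity");
    sizes = new_sizes;
  }
};

int main() {
  BoundedTensor tokens({1, 128}); // max shape, e.g. {1, seq_len}
  tokens.resize({1, 1});          // first input is a single token
  tokens.resize({1, 2});          // grow by one per step, within the bound
  return 0;
}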

extension/runner_util/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -38,5 +38,6 @@ def define_common_targets():
         ],
         deps = [
             "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
         ],
     )
