Commit b4750be

JacobSzwejbka authored and facebook-github-bot committed
refactor runner to fix bugs (#2752)
Summary: Lots of bugs before with the kv cache implementation of the runner. Trying this out.

Differential Revision: D55496044
1 parent 292ea5f commit b4750be

File tree

5 files changed: +160 -96 lines changed

examples/models/llama2/runner/runner.cpp

Lines changed: 137 additions & 94 deletions
@@ -10,6 +10,7 @@
 // The module takes in a string as input and emits a string as output.
 
 #include <executorch/examples/models/llama2/runner/runner.h>
+#include <executorch/extension/evalue_util/print_evalue.h>
 #include <executorch/extension/runner_util/managed_tensor.h>
 
 #include <ctime>
@@ -121,24 +122,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
   return res;
 }
 
-std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
-  // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
-  // self.head_dim)
-  std::vector<std::string> methods = {
-      "get_n_layers",
-      "get_max_batch_size",
-      "get_max_seq_len",
-      "get_n_kv_heads",
-      "get_head_dim"};
-  std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
-  std::vector<exec_aten::SizesType> result;
-  for (int i = 0; i < methods.size(); ++i) {
-    // convert from int64_t to int32_t
-    result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
-  }
-  return result;
-}
-
 template <typename T>
 int32_t Runner::logitsToToken(
     const exec_aten::Tensor& logits_tensor,
@@ -155,6 +138,73 @@ int32_t Runner::logitsToToken(
   return sampler_->sample(logits_last);
 }
 
+// Given an input token, set up the inputs for the model and execute a single
+// step, returning the logits tensor.
+Result<torch::executor::Tensor> Runner::run_model_step(
+    int64_t input_token,
+    ManagedTensor& managed_tokens,
+    ManagedTensor& managed_start_pos,
+    size_t max_seq_len) {
+  // ET_LOG(Info, "Input token %" PRIu64, input_token);
+  if (use_kv_cache_) {
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When using kv-cache our input is always 1 token, so just update to the
+    // latest.
+    tokens.mutable_data_ptr<int64_t>()[0] = input_token;
+
+    // inputs:[tokens, start_pos]
+    inputs.push_back(tokens);
+    inputs.push_back(start_pos);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More than one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    // Bump start_pos by 1
+    start_pos.mutable_data_ptr<int64_t>()[0]++;
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  } else { // no kv cache
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When not using kv-cache our input is the entire history of tokens we have
+    // seen, so resize input to be 1 larger and append the new token to the end.
+    // TODO does this work in ATen mode?
+    tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;
+
+    // inputs:[tokens]
+    inputs.push_back(tokens);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More than one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    if (tokens.size(1) < max_seq_len) {
+      // Resize the tokens tensor to be 1 larger for next step.
+      managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
+    }
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  }
+}
+
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
@@ -189,9 +239,6 @@ Error Runner::generate(
       prompt_tokens,
       &num_prompt_tokens);
 
-  for (int i = 0; i < num_prompt_tokens; i++) {
-    ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
-  }
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
       num_prompt_tokens < max_seq_len_,
@@ -202,89 +249,94 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   // start the main loop
-  int next; // will store the next token in the sequence
-  int64_t pos = num_prompt_tokens - 1; // position in the sequence
-  int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
-  int logits_index = 0; // index of the logits tensor in the output
-  std::vector<exec_aten::SizesType> input_shape = {1, 1};
-  std::vector<exec_aten::SizesType> pos_shape = {1};
+  int64_t pos = 0; // position in the sequence
+
   std::vector<int64_t> token_data; // allocate space for the tokens
-  std::vector<int64_t> pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
+
+  std::vector<int64_t> start_pos_data; // allocate space for the start position
+  std::vector<exec_aten::SizesType> start_pos_shape = {1};
 
   if (use_kv_cache_) {
-    // set pos to 0, refill token by token
-    pos = 0;
+    // hard code these to size 1 as kv cache is locked to static size right now.
    token_data.resize(1);
-    pos_data.resize(seq_len);
+    token_shape[1] = 1;
+    start_pos_data.resize(1);
+    start_pos_data[0] = 0;
  } else {
-    // reserve data for tokens, notice the size is still 0.
+    // reserve data for tokens, notice the size is still 0 but the capacity is
+    // seq_len.
     token_data.resize(seq_len);
   }
 
   // initialize tensor wrappers
-  ManagedTensor pos_managed(
-      pos_data.data(), pos_data.size(), pos_shape, ScalarType::Long);
-
-  // copy prompt tokens into data
-  for (int i = 0; i <= pos; ++i) {
-    // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-    token_data[i] = prompt_tokens[i];
-    if (i > 0) {
-      printf(
-          "%s",
-          ET_UNWRAP(
-              tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
+  ManagedTensor tokens_managed(
+      token_data.data(),
+      128, // TODO clean up unused 128 here as ManagedTensor ignores this arg in
+           // ctor
+      token_shape,
+      ScalarType::Long);
+  // Create with the max shape to appropriately set the capacity of this
+  // tensor, then resize back to 1 for first input.
+  tokens_managed.resize({1, 1});
+
+  ManagedTensor start_pos_managed(
+      start_pos_data.data(), 128, start_pos_shape, ScalarType::Long);
+
+  int64_t prev_token = -1;
+  int64_t cur_token = prompt_tokens[0];
+
+  // If we aren't using the kv cache then we can batch prefill the prompt
+  if (!use_kv_cache_) {
+    tokens_managed.resize({1, num_prompt_tokens});
+    for (int i = 0; i < num_prompt_tokens - 1; i++) {
+      tokens_managed.get_aliasing_tensor().mutable_data_ptr<int64_t>()[i] =
+          prompt_tokens[i];
+    }
+    // prefill tokens up to the last prompt token and then enter the loop with
+    // the last prompt token as the current token.
+    cur_token = prompt_tokens[num_prompt_tokens - 1];
+    pos = num_prompt_tokens - 1;
+
+    // Print the prompt for consistent output between single token prefill and
+    // batch prefill.
+    int prev = prompt_tokens[0];
+    int cur = -1;
+    for (int i = 1; i < num_prompt_tokens; i++) {
+      cur = prompt_tokens[i];
+      auto piece_res = tokenizer_->decode(prev, cur);
+      ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
+      util::safe_printf(piece_res.get());
+      fflush(stdout);
+      prev = cur;
     }
   }
 
-  // create a 1xN int tensor with next as value
-  while (pos + 1 < seq_len) {
-    // ET_LOG(Info, "Generating step %d...", pos);
-    // set the current token in the tensor
-    std::vector<EValue> inputs;
-    if (use_kv_cache_) {
-      token_data[0] = token;
-      input_shape[1] = 1;
-      // inputs: [tokens, start_pos, k_cache, v_cache]
-      inputs.emplace_back(pos_managed.get_aliasing_tensor());
-    } else {
-      // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-      token_data[pos] = token;
-      input_shape[1] = pos + 1;
-    }
-    ManagedTensor token_managed(
-        token_data.data(), token_data.size(), input_shape, ScalarType::Long);
-    inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
-    // For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
-    // Otherwise inputs: [tokens]
-    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
-    ET_CHECK_MSG(
-        outputs_res.ok(),
-        "Execution of method forward failed with status 0x%" PRIx32,
-        static_cast<int32_t>(outputs_res.error()));
-    // ET_LOG(Info, "Model executed successfully.");
+  // Generate our tokens
+  while (pos < seq_len - 1) {
+    // Run the model
+    Result<torch::executor::Tensor> logits_res =
+        run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
 
-    std::vector<EValue> outputs = outputs_res.get();
-    // Check the outputs.
-    ET_CHECK_MSG(
-        outputs.size() == 1 && outputs.at(0).isTensor(),
-        "Expecting output to have exactly 1 tensor output. Got %zu outputs.",
-        outputs.size());
     if (pos == num_prompt_tokens) {
       timers_.first_token_ms = util::time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
       timers_.prompt_eval_end_ms = util::time_in_ms();
     }
-    int32_t next_tok;
-    exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
+
+    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+    exec_aten::Tensor& logits_tensor = logits_res.get();
+
+    prev_token = cur_token;
+
     long sample_start_time_ms = util::time_in_ms();
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
-        next_tok = logitsToToken<float>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<float>(logits_tensor, pos, 0);
         break;
       }
      case ScalarType::Half: {
-        next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
         break;
      }
      default:
@@ -299,19 +351,12 @@ Error Runner::generate(
     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
       // prefill, force the next token to be the next prompt token
-      next = prompt_tokens[pos + 1];
-    } else {
-      // otherwise sample the next token from the logits
-      next = next_tok;
+      cur_token = prompt_tokens[pos + 1];
     }
-    // ET_LOG(Info, "Output saved, next = %d", next);
     pos++;
-    if (use_kv_cache_) {
-      pos_data.at(0) = pos;
-    }
 
     // print the token as string, decode it with the Tokenizer object
-    auto piece_res = tokenizer_->decode(token, next);
+    auto piece_res = tokenizer_->decode(prev_token, cur_token);
     ET_CHECK(piece_res.ok());
     const char* piece = piece_res.get();
 
@@ -328,22 +373,20 @@
     }
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (pos >= num_prompt_tokens && next == eos_id_) {
+    if (pos >= num_prompt_tokens && cur_token == eos_id_) {
       printf("\n");
       ET_LOG(Info, "\nReached to the end of generation");
       break;
     }
-
-    token = next;
   }
   timers_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
-  if (pos + 1 == seq_len) {
+  if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
-  timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
+  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
 
   delete[] prompt_tokens;
   return Error::Ok;
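To make the refactored control flow easier to follow, here is a minimal, self-contained sketch (not ExecuTorch code) of the loop structure generate() now uses: prefill determines the starting position and current token, and a single while-loop performs one model step per position for both the kv-cache and non-kv-cache paths, forcing prompt tokens while prefilling and stopping on EOS afterwards. run_dummy_model_step() is a hypothetical stand-in for Runner::run_model_step(); the real method returns logits that are then sampled.

// Simplified, self-contained sketch of the refactored generation loop.
// run_dummy_model_step() is a hypothetical stand-in for run_model_step():
// it pretends the model always predicts (last token + 1) and bumps start_pos,
// mirroring the kv-cache branch above.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t run_dummy_model_step(int64_t input_token, int64_t& start_pos) {
  start_pos++; // "Bump start_pos by 1", as in the kv-cache branch
  return input_token + 1;
}

int main() {
  const std::vector<int64_t> prompt_tokens = {5, 6, 7};
  const int64_t num_prompt_tokens = static_cast<int64_t>(prompt_tokens.size());
  const int64_t seq_len = 10;
  const int64_t eos_id = 100;

  int64_t pos = 0; // position in the sequence
  int64_t start_pos = 0;
  int64_t cur_token = prompt_tokens[0];

  while (pos < seq_len - 1) {
    // One model invocation per step (run_model_step in the real runner).
    int64_t next_token = run_dummy_model_step(cur_token, start_pos);

    if (pos < num_prompt_tokens - 1) {
      // Still prefilling: force the next prompt token instead of the sample.
      next_token = prompt_tokens[pos + 1];
    }
    pos++;

    printf("pos=%lld token=%lld\n", (long long)pos, (long long)next_token);

    if (pos >= num_prompt_tokens && next_token == eos_id) {
      break; // data-dependent stop on EOS once the prompt is consumed
    }
    cur_token = next_token;
  }
  return 0;
}

The real loop additionally tracks prev_token so the tokenizer can decode each (prev, cur) pair for printing, and records the first-token and prompt-eval timestamps shown in the diff.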

examples/models/llama2/runner/runner.h

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@
 #include <executorch/examples/models/llama2/sampler/sampler.h>
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
+#include <executorch/extension/runner_util/managed_tensor.h>
 
 namespace torch::executor {
 
@@ -45,7 +46,11 @@ class Runner {
   template <typename T>
   int32_t
   logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
-  std::vector<exec_aten::SizesType> getKVCacheShape();
+  Result<torch::executor::Tensor> run_model_step(
+      int64_t input_token,
+      ManagedTensor& tokens,
+      ManagedTensor& start_pos,
+      size_t max_seq_len);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;

examples/models/llama2/runner/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ def define_common_targets():
             "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/kernels/quantized:generated_lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
         ] + (_get_operator_lib(aten)) + ([
             # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
             # Therefore enable it explicitly for now to avoid failing tests

extension/runner_util/managed_tensor.h

Lines changed: 14 additions & 1 deletion
@@ -8,6 +8,9 @@
 
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/platform/assert.h>
+
 #ifdef USE_ATEN_LIB
 #include <torch/torch.h>
 #else
@@ -56,10 +59,20 @@ class ManagedTensor {
         data_ptr_,
         dim_order_.data(),
         strides_.data(),
-        TensorShapeDynamism::STATIC);
+        TensorShapeDynamism::DYNAMIC_BOUND);
 #endif
   }
 
+  void resize(const std::vector<SizesType>& new_sizes) {
+    ET_CHECK_MSG(
+        new_sizes.size() == sizes_.size(),
+        "Cannot change rank of a managed tensor");
+    auto err = resize_tensor(
+        this->get_aliasing_tensor(),
+        exec_aten::ArrayRef<SizesType>(new_sizes.data(), new_sizes.size()));
+    ET_CHECK(err == Error::Ok);
+  }
+
   /**
    * Get the underlying Tensor object. This is assuming the copying is cheap.
    */
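The switch to TensorShapeDynamism::DYNAMIC_BOUND plus the new resize() helper is what lets the runner allocate the token tensor once at its maximum size and then grow or shrink its reported shape step by step. Below is a rough usage sketch, assuming the constructor signature used by the runner above (data pointer, a currently ignored numel argument, sizes, dtype) and the torch::executor namespace; it is an illustration, not part of this commit.

// Hypothetical usage sketch of ManagedTensor::resize(); names, headers, and
// namespaces are assumed from the runner code shown in this commit.
#include <executorch/extension/runner_util/managed_tensor.h>

#include <vector>

using namespace torch::executor;

void tokens_example(int32_t seq_len) {
  // Backing storage is allocated once at the maximum length.
  std::vector<int64_t> token_data(seq_len);
  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};

  // Constructing with the max shape records seq_len as the capacity;
  // DYNAMIC_BOUND allows resizing anywhere up to that bound.
  ManagedTensor tokens_managed(
      token_data.data(), token_data.size(), token_shape, ScalarType::Long);

  // Shrink to a single token for the first step...
  tokens_managed.resize({1, 1});

  // ...then grow by one slot per generation step, as run_model_step() does.
  auto tokens = tokens_managed.get_aliasing_tensor();
  tokens_managed.resize({1, static_cast<int>(tokens.size(1) + 1)});
}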

extension/runner_util/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -38,5 +38,6 @@ def define_common_targets():
         ],
         deps = [
             "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
         ],
     )
