
Commit b82b583

JacobSzwejbka authored and facebook-github-bot committed
refactor runner to fix bugs
Summary: Lots of bugs before with kv cache implementation of runner. Trying this out.

Differential Revision: D55496044
1 parent 923ff39 commit b82b583

File tree

5 files changed, +135 -98 lines changed


examples/models/llama2/runner/runner.cpp

Lines changed: 112 additions & 96 deletions
@@ -121,24 +121,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
   return res;
 }
 
-std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
-  // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
-  // self.head_dim)
-  std::vector<std::string> methods = {
-      "get_n_layers",
-      "get_max_batch_size",
-      "get_max_seq_len",
-      "get_n_kv_heads",
-      "get_head_dim"};
-  std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
-  std::vector<exec_aten::SizesType> result;
-  for (int i = 0; i < methods.size(); ++i) {
-    // convert from int64_t to int32_t
-    result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
-  }
-  return result;
-}
-
 template <typename T>
 int32_t Runner::logitsToToken(
     const exec_aten::Tensor& logits_tensor,
@@ -155,6 +137,73 @@ int32_t Runner::logitsToToken(
   return sampler_->sample(logits_last);
 }
 
+// Given an input token. Set up the inputs for the model and execute a single
+// step. Returning the logits tensor.
+Result<torch::executor::Tensor> Runner::run_model_step(
+    int64_t input_token,
+    ManagedTensor& managed_tokens,
+    ManagedTensor& managed_start_pos,
+    size_t max_seq_len) {
+  // ET_LOG(Info, "Input token %" PRIu64, input_token);
+  if (use_kv_cache_) {
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When using kv-cache our input is always 1 token, so just update to the
+    // latest.
+    tokens.mutable_data_ptr<int64_t>()[0] = input_token;
+
+    // inputs:[tokens, start_pos]
+    inputs.push_back(tokens);
+    inputs.push_back(start_pos);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More then one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    // Bump start_pos by 1
+    start_pos.mutable_data_ptr<int64_t>()[0]++;
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  } else { // no kv cache
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When not using kv-cache our input is the entire history of tokens we have
+    // seen, so resize input to be 1 larger and append the new token to the end.
+    // TODO does this work in ATen mode?
+    tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;
+
+    // inputs:[tokens]
+    inputs.push_back(tokens);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More then one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    if (tokens.size(1) < max_seq_len) {
+      // Resize the tokens tensor to be 1 larger for next step.
+      managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
+    }
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  }
+}
+
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
@@ -189,9 +238,6 @@ Error Runner::generate(
       prompt_tokens,
       &num_prompt_tokens);
 
-  for (int i = 0; i < num_prompt_tokens; i++) {
-    ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
-  }
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
       num_prompt_tokens < max_seq_len_,
@@ -202,89 +248,68 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   // start the main loop
-  int next; // will store the next token in the sequence
-  int64_t pos = num_prompt_tokens - 1; // position in the sequence
-  int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
-  int logits_index = 0; // index of the logits tensor in the output
-  std::vector<exec_aten::SizesType> input_shape = {1, 1};
-  std::vector<exec_aten::SizesType> pos_shape = {1};
+  int64_t pos = 0; // position in the sequence
+
   std::vector<int64_t> token_data; // allocate space for the tokens
-  std::vector<int64_t> pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
+
+  std::vector<int64_t> start_pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> start_pos_shape = {1};
 
   if (use_kv_cache_) {
-    // set pos to 0, refill token by token
-    pos = 0;
+    // hard code these to size 1 as kv cache is locked to static size right now.
     token_data.resize(1);
-    pos_data.resize(seq_len);
+    token_shape[1] = 1;
+    start_pos_data.resize(1);
+    start_pos_data[0] = 0;
   } else {
-    // reserve data for tokens, notice the size is still 0.
+    // reserve data for tokens, notice the size is still 0 but the capacity is
+    // seq_len.
     token_data.resize(seq_len);
   }
 
   // initialize tensor wrappers
-  ManagedTensor pos_managed(
-      pos_data.data(), pos_data.size(), pos_shape, ScalarType::Long);
-
-  // copy prompt tokens into data
-  for (int i = 0; i <= pos; ++i) {
-    // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-    token_data[i] = prompt_tokens[i];
-    if (i > 0) {
-      printf(
-          "%s",
-          ET_UNWRAP(
-              tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
-    }
-  }
-
-  // create a 1xN int tensor with next as value
-  while (pos + 1 < seq_len) {
+  ManagedTensor tokens_managed(
+      token_data.data(),
+      128, // TODO clean up unused 128 here as ManagedTensor ignores this arg in
+           // ctor
+      token_shape,
+      ScalarType::Long);
+  // Create with the max shape to approapriately set the capacity of this
+  // tensor, then resize back to 1 for first input.
+  tokens_managed.resize({1, 1});
+
+  ManagedTensor start_pos_managed(
+      start_pos_data.data(), 128, start_pos_shape, ScalarType::Long);
+
+  int64_t prev_token = -1;
+  int64_t cur_token = prompt_tokens[0];
+
+  while (pos < seq_len - 1) {
     // ET_LOG(Info, "Generating step %d...", pos);
-    // set the current token in the tensor
-    std::vector<EValue> inputs;
-    if (use_kv_cache_) {
-      token_data[0] = token;
-      input_shape[1] = 1;
-      // inputs: [tokens, start_pos, k_cache, v_cache]
-      inputs.emplace_back(pos_managed.get_aliasing_tensor());
-    } else {
-      // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-      token_data[pos] = token;
-      input_shape[1] = pos + 1;
-    }
-    ManagedTensor token_managed(
-        token_data.data(), token_data.size(), input_shape, ScalarType::Long);
-    inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
-    // For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
-    // Otherwise inputs: [tokens]
-    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
-    ET_CHECK_MSG(
-        outputs_res.ok(),
-        "Execution of method forward failed with status 0x%" PRIx32,
-        static_cast<int32_t>(outputs_res.error()));
-    // ET_LOG(Info, "Model executed successfully.");
 
-    std::vector<EValue> outputs = outputs_res.get();
-    // Check the outputs.
-    ET_CHECK_MSG(
-        outputs.size() == 1 && outputs.at(0).isTensor(),
-        "Expecting output to have exactly 1 tensor output. Got %zu outputs.",
-        outputs.size());
+    Result<torch::executor::Tensor> logits_res =
+        run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
+
     if (pos == num_prompt_tokens) {
       timers_.first_token_ms = util::time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
       timers_.prompt_eval_end_ms = util::time_in_ms();
     }
-    int32_t next_tok;
-    exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
+
+    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+    exec_aten::Tensor& logits_tensor = logits_res.get();
+
+    prev_token = cur_token;
+
     long sample_start_time_ms = util::time_in_ms();
     switch (logits_tensor.scalar_type()) {
      case ScalarType::Float: {
-        next_tok = logitsToToken<float>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<float>(logits_tensor, pos, 0);
        break;
      }
      case ScalarType::Half: {
-        next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
        break;
      }
      default:
@@ -299,19 +324,12 @@ Error Runner::generate(
     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
       // prefill, force the next token to be the next prompt token
-      next = prompt_tokens[pos + 1];
-    } else {
-      // otherwise sample the next token from the logits
-      next = next_tok;
+      cur_token = prompt_tokens[pos + 1];
     }
-    // ET_LOG(Info, "Output saved, next = %d", next);
     pos++;
-    if (use_kv_cache_) {
-      pos_data.at(0) = pos;
-    }
 
     // print the token as string, decode it with the Tokenizer object
-    auto piece_res = tokenizer_->decode(token, next);
+    auto piece_res = tokenizer_->decode(prev_token, cur_token);
     ET_CHECK(piece_res.ok());
     const char* piece = piece_res.get();
 
@@ -328,22 +346,20 @@ Error Runner::generate(
     }
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (pos >= num_prompt_tokens && next == eos_id_) {
+    if (pos >= num_prompt_tokens && cur_token == eos_id_) {
       printf("\n");
       ET_LOG(Info, "\nReached to the end of generation");
       break;
     }
-
-    token = next;
   }
   timers_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
-  if (pos + 1 == seq_len) {
+  if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
-  timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
+  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
 
   delete[] prompt_tokens;
   return Error::Ok;
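To make the refactored control flow above easier to follow, here is a small standalone sketch, not part of the commit, of the prefill-then-decode loop that generate() now uses: every position goes through a model step, prefill positions override the sampled token with the next prompt token, and the prev_token/cur_token pair drives the tokenizer's pairwise decode. The model call and sampler are stubbed out; sample_from_logits is a made-up placeholder, not an ExecuTorch API.

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for the sampler; the real runner samples from the logits tensor.
static int64_t sample_from_logits(int64_t prev_token) {
  return prev_token + 1;
}

int main() {
  const std::vector<int64_t> prompt_tokens = {1, 15043, 3186}; // illustrative prompt
  const int64_t seq_len = 8;
  const int64_t eos_id = 2;

  int64_t pos = 0;
  int64_t prev_token = -1;
  int64_t cur_token = prompt_tokens[0];

  while (pos < seq_len - 1) {
    // Stand-in for run_model_step(cur_token, ...): one forward pass at `pos`.
    const int64_t sampled = sample_from_logits(cur_token);

    prev_token = cur_token;
    cur_token = sampled;

    // Prefill: force the next prompt token instead of the sampled one.
    if (pos < static_cast<int64_t>(prompt_tokens.size()) - 1) {
      cur_token = prompt_tokens[pos + 1];
    }
    pos++;

    // The real runner decodes the (prev_token, cur_token) pair here.
    std::printf("pos=%lld prev=%lld cur=%lld\n",
                static_cast<long long>(pos),
                static_cast<long long>(prev_token),
                static_cast<long long>(cur_token));

    if (pos >= static_cast<int64_t>(prompt_tokens.size()) && cur_token == eos_id) {
      break;
    }
  }
  return 0;
}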

examples/models/llama2/runner/runner.h

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@
 #include <executorch/examples/models/llama2/sampler/sampler.h>
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
+#include <executorch/extension/runner_util/managed_tensor.h>
 
 namespace torch::executor {
 
@@ -45,7 +46,11 @@ class Runner {
   template <typename T>
   int32_t
   logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
-  std::vector<exec_aten::SizesType> getKVCacheShape();
+  Result<torch::executor::Tensor> run_model_step(
+      int64_t input_token,
+      ManagedTensor& tokens,
+      ManagedTensor& start_pos,
+      size_t max_seq_len);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;
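As a reading aid for the new run_model_step declaration, a minimal standalone sketch, not ExecuTorch code, of the two input layouts its implementation in runner.cpp prepares: with a KV cache the model receives a single token plus an explicit start position, while without one it receives the whole token history, grown by one element each step. Names and values below are illustrative.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> history = {1, 15043, 3186}; // tokens seen so far
  const int64_t cur_token = history.back();
  const int64_t pos = static_cast<int64_t>(history.size()) - 1;

  // KV-cache path: input is one token plus its position; earlier context
  // lives in the cache, so the tokens tensor stays shape [1, 1].
  std::vector<int64_t> kv_tokens = {cur_token};
  std::vector<int64_t> start_pos = {pos};

  // No-cache path: input is the full history, shape [1, pos + 1], and the
  // tensor is resized one element larger before the next step.
  std::vector<int64_t> full_tokens = history;

  std::printf("kv input: %zu token(s), start_pos=%lld\n",
              kv_tokens.size(), static_cast<long long>(start_pos[0]));
  std::printf("no-cache input: %zu token(s)\n", full_tokens.size());
  return 0;
}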

examples/models/llama2/runner/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ def define_common_targets():
             "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/kernels/quantized:generated_lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
         ] + (_get_operator_lib(aten)) + ([
             # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
             # Therefore enable it explicitly for now to avoid failing tests

extension/runner_util/managed_tensor.h

Lines changed: 14 additions & 1 deletion
@@ -8,6 +8,9 @@
 
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/platform/assert.h>
+
 #ifdef USE_ATEN_LIB
 #include <torch/torch.h>
 #else
@@ -56,10 +59,20 @@ class ManagedTensor {
         data_ptr_,
         dim_order_.data(),
         strides_.data(),
-        TensorShapeDynamism::STATIC);
+        TensorShapeDynamism::DYNAMIC_BOUND);
 #endif
   }
 
+  void resize(const std::vector<SizesType>& new_sizes) {
+    ET_CHECK_MSG(
+        new_sizes.size() == sizes_.size(),
+        "Cannot change rank of a managed tensor");
+    auto err = resize_tensor(
+        this->get_aliasing_tensor(),
+        exec_aten::ArrayRef<SizesType>(new_sizes.data(), new_sizes.size()));
+    ET_CHECK(err == Error::Ok);
+  }
+
   /**
    * Get the underlying Tensor object. This is assuming the copying is cheap.
    */
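The switch to TensorShapeDynamism::DYNAMIC_BOUND plus the new resize() is what lets the runner construct a tensor at its maximum shape and then shrink it per step. Below is a minimal usage sketch mirroring the pattern in runner.cpp above; it assumes the ExecuTorch headers are on the include path, and the unused size argument of 128 follows the TODO noted in the runner.

#include <executorch/extension/runner_util/managed_tensor.h>

#include <cstdint>
#include <vector>

using namespace torch::executor;

void prepare_tokens_tensor(std::vector<int64_t>& token_data, int32_t seq_len) {
  // Construct with the maximum shape so the bounded-dynamic tensor's capacity
  // covers the whole sequence...
  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
  ManagedTensor tokens_managed(
      token_data.data(), 128, token_shape, ScalarType::Long);

  // ...then shrink to the shape actually needed for the first step; resize()
  // keeps the rank fixed and checks that the resize succeeds.
  tokens_managed.resize({1, 1});
}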

extension/runner_util/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -38,5 +38,6 @@ def define_common_targets():
         ],
         deps = [
             "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
         ],
     )
