
Commit cde514c

kimishpatel authored and facebook-github-bot committed
fix runner max seq len (#2688)
Summary: Pull Request resolved: #2688

When the max seq len argument is not passed, the runner uses the max seq len from the model, which means the number of tokens generated should equal the KV cache size. However, the generate loop tried to produce one more token than that, because pos, a 0-based index, was used as the token count. A minimal sketch of this off-by-one follows the diff below.

Reviewed By: mergennachin, digantdesai

Differential Revision: D55369776

fbshipit-source-id: 7beb38177a23449649e96184b0b0a0bb507c199f
1 parent 61b1b83

File tree

1 file changed (+3 −3 lines)

examples/models/llama2/runner/runner.cpp

Lines changed: 3 additions & 3 deletions
@@ -238,7 +238,7 @@ Error Runner::generate(
   }

   // create a 1xN int tensor with next as value
-  while (pos < seq_len) {
+  while (pos + 1 < seq_len) {
     // ET_LOG(Info, "Generating step %d...", pos);
     // set the current token in the tensor
     std::vector<EValue> inputs;
@@ -339,11 +339,11 @@ Error Runner::generate(
   timers_.inference_end_ms = util::time_in_ms();
   printf("\n");

-  if (pos == seq_len) {
+  if (pos + 1 == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }

-  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+  timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);

   delete[] prompt_tokens;
   return Error::Ok;
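The off-by-one is easiest to see in isolation. Below is a minimal standalone sketch (hypothetical, not the runner's actual code): pos is a 0-based position index and each iteration fills the KV-cache slot at pos + 1, so when the loop exits the sequence holds pos + 1 tokens, and the bound must compare pos + 1 (not pos) against seq_len.

#include <cstdio>
#include <vector>

int main() {
  const int seq_len = 4;            // pretend the KV cache holds positions 0..3
  std::vector<int> kv_cache(seq_len, 0);

  int pos = 0;                      // 0-based index of the current token
  // The old bound `pos < seq_len` would run one extra iteration and
  // write to kv_cache[seq_len], one slot past the end. The fixed bound
  // stops once the token at position seq_len - 1 has been produced.
  while (pos + 1 < seq_len) {
    kv_cache[pos + 1] = pos + 1;    // "generate" the token for the next position
    ++pos;
  }

  // pos is an index, so the sequence now holds pos + 1 tokens in total,
  // which is why the report uses (pos + 1) - num_prompt_tokens.
  printf("sequence length: %d of %d cache slots\n", pos + 1, seq_len);
  return 0;
}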
