Commit caee336

Varun Puri authored and facebook-github-bot committed
Add time to first token for llama runner (#2141)
Summary:
Pull Request resolved: #2141

Add time to first generated token & other features - Since we're measuring the first token time, the token rate is measured both at the

* Model Load Time - just a timer around ET_CHECK_OK_OR_RETURN_ERROR(load());
* Total inference time - Immediately after model load until the end of the inference loop
  * First token time - From immediately after the model load until the first generated (not prompt) token is printed.
    * Prompt eval - (comparable to llama.cpp prompt_eval_time) prompt array allocation and tokenization. Ends right before the inference loop starts
  * Remaining tokens - immediately after the first token is outputted until the end of the inference loop
  * Net eval time - (comparable to llama.cpp eval_time) Total time spent generating tokens.
* Sample time - amount of time spent sampling per token (present in llama.cpp)

bypass-github-executorch-ci-checks bypass-github-pytorch-ci-checks

Reviewed By: digantdesai, Jack-Khuu

Differential Revision: D54223564

fbshipit-source-id: 3846903b56d20e2d4159fae63de3d471e5677c51
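To make the stage boundaries above concrete, here is a small, self-contained C++ sketch (not part of the commit) showing how the reported tokens-per-second rates fall out of millisecond timestamps like the ones this change records. The variable names mirror the new TimeStamps fields; the values are purely hypothetical.

// Illustrative only -- not part of the commit. Names mirror the TimeStamps
// fields added here; the numbers are made up.
#include <cstdio>

int main() {
  const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; // timestamps are in ms
  long inference_start_ms = 0;   // hypothetical: inference starts at t = 0
  long prompt_eval_end_ms = 150; // hypothetical: prompt tokenized/prefilled by t = 150 ms
  long inference_end_ms = 2150;  // hypothetical: generation loop ends at t = 2150 ms
  long num_prompt_tokens = 8;
  long num_generated_tokens = 64;

  // Prompt evaluation rate (comparable to llama.cpp prompt_eval_time).
  double prompt_eval_ms = (double)(prompt_eval_end_ms - inference_start_ms);
  double prompt_tps =
      num_prompt_tokens / prompt_eval_ms * SCALING_FACTOR_UNITS_PER_SECOND;

  // Net eval rate (comparable to llama.cpp eval_time): generated tokens over
  // the remainder of the inference loop.
  double eval_ms = (double)(inference_end_ms - prompt_eval_end_ms);
  double gen_tps =
      num_generated_tokens / eval_ms * SCALING_FACTOR_UNITS_PER_SECOND;

  printf("prompt eval: %.2f tok/s, generation: %.2f tok/s\n", prompt_tps, gen_tps);
  return 0;
}

This is the same arithmetic printReport() performs below, just isolated so the relationship between the timestamp deltas and the reported rates is easy to follow.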
1 parent 9e83fde commit caee336

File tree

2 files changed: +133 -17 lines changed


examples/models/llama2/runner/runner.cpp

Lines changed: 103 additions & 17 deletions
@@ -14,6 +14,7 @@

 #include <ctime>
 #include <memory>
+#include <sstream>

 #ifdef USE_ATEN_LIB
 #include <torch/torch.h>
@@ -161,8 +162,16 @@ Error Runner::generate(
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
-  ET_CHECK_OK_OR_RETURN_ERROR(load());
+  if (!is_loaded()) {
+    timers_.model_load_start_ms = util::time_in_ms();
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+    timers_.model_load_end_ms = util::time_in_ms();
+  }
+
+  // First token time only measures the time it takes to encode the prompt and
+  // return a response token.

+  timers_.inference_start_ms = util::time_in_ms();
   shouldStop_ = false;

   // encode the (string) prompt into tokens sequence
@@ -179,6 +188,7 @@ Error Runner::generate(
       append_eos_ ? n_eos_ : 0,
       prompt_tokens,
       &num_prompt_tokens);
+
   for (int i = 0; i < num_prompt_tokens; i++) {
     ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
   }
@@ -192,8 +202,6 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");

   // start the main loop
-  long start =
-      0; // used to time our code, only initialized after first iteration
   int next; // will store the next token in the sequence
   int64_t pos = num_prompt_tokens - 1; // position in the sequence
   int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
@@ -254,6 +262,7 @@ Error Runner::generate(
           tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
     }
   }
+
   // create a 1xN int tensor with next as value
   while (pos < seq_len) {
     // ET_LOG(Info, "Generating step %d...", pos);
@@ -289,10 +298,14 @@ Error Runner::generate(
         outputs.size() > 0,
         "Expecting output to have at least one evalue. Got %zu",
         outputs.size());
-
+    if (pos == num_prompt_tokens) {
+      timers_.first_token_ms = util::time_in_ms();
+    } else if (pos == num_prompt_tokens - 1) {
+      timers_.prompt_eval_end_ms = util::time_in_ms();
+    }
     int32_t next_tok;
     exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
-
+    long sample_start_time_ms = util::time_in_ms();
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
         next_tok = logitsToToken<float>(logits_tensor, pos, 0);
@@ -308,6 +321,8 @@ Error Runner::generate(
             "Unsupported dtype output %hhd",
             static_cast<int8_t>(logits_tensor.scalar_type()));
     }
+    timers_.aggregate_sampling_time_ms +=
+        util::time_in_ms() - sample_start_time_ms;

     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
@@ -339,16 +354,13 @@ Error Runner::generate(

     // data-dependent terminating condition: we have n_eos_ number of EOS
     if (pos >= num_prompt_tokens && next == eos_id_) {
-      ET_LOG(Info, "Reached to the end of generation");
+      printf("\n");
+      ET_LOG(Info, "\nReached to the end of generation");
       break;
     }

     token = next;

-    // init the timer here because the first iteration can be slower
-    if (start == 0) {
-      start = util::time_in_ms();
-    }
     if (use_kv_cache_) {
       // outputs: [k_cache, v_cache, logits, k_cache, v_cache]
       memcpy(
@@ -361,23 +373,97 @@ Error Runner::generate(
           v_data.size());
     }
   }
+  timers_.inference_end_ms = util::time_in_ms();
   printf("\n");

   if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
-  // report achieved tok/s (pos-1 because the timer starts after first
-  // iteration)
-  if (pos >= 1) {
-    long end = util::time_in_ms();
-    ET_LOG(
-        Info, "Achieved tok/s: %f\n", (pos - 1) / (double)(end - start) * 1000);
-  }
+
+  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);

   delete[] prompt_tokens;
   return Error::Ok;
 }

+void Runner::TimeStamps::printReport(
+    const int64_t& num_prompt_tokens,
+    const int64_t& num_generated_tokens) {
+  printf(
+      "PyTorchObserver %s\n",
+      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+
+  ET_LOG(
+      Info,
+      "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
+      num_prompt_tokens,
+      num_generated_tokens);
+
+  ET_LOG(
+      Info,
+      "\tModel Load Time:\t\t%f (seconds)",
+      ((double)(model_load_end_ms - model_load_start_ms) /
+       SCALING_FACTOR_UNITS_PER_SECOND));
+  double inference_time_ms = (double)(inference_end_ms - inference_start_ms);
+  ET_LOG(
+      Info,
+      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      inference_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+
+      (num_generated_tokens) / (double)(inference_end_ms - inference_start_ms) *
+          SCALING_FACTOR_UNITS_PER_SECOND);
+  double prompt_eval_time = (double)(prompt_eval_end_ms - inference_start_ms);
+  ET_LOG(
+      Info,
+      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      prompt_eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
+      (num_prompt_tokens) / prompt_eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+
+  double eval_time = (double)(inference_end_ms - prompt_eval_end_ms);
+  ET_LOG(
+      Info,
+      "\t\tGenerated %" PRIu64
+      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      num_generated_tokens,
+      eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
+      num_generated_tokens / eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+
+  // Time to first token is measured from the start of inference, excluding
+  // model load time.
+  ET_LOG(
+      Info,
+      "\tTime to first generated token:\t%f (seconds)",
+      ((double)(first_token_ms - inference_start_ms) /
+       SCALING_FACTOR_UNITS_PER_SECOND));
+
+  ET_LOG(
+      Info,
+      "\tSampling time over %" PRIu64
+      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      num_prompt_tokens + num_generated_tokens,
+      (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+      (num_prompt_tokens + num_generated_tokens) /
+          (double)aggregate_sampling_time_ms * SCALING_FACTOR_UNITS_PER_SECOND);
+}
+
+const std::string Runner::TimeStamps::toJsonString(
+    const int64_t& num_prompt_tokens,
+    const int64_t& num_generated_tokens) {
+  std::stringstream ss;
+  ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
+     << "\"generated_tokens\":" << num_generated_tokens << ","
+     << "\"model_load_start_ms\":" << model_load_start_ms << ","
+     << "\"model_load_end_ms\":" << model_load_end_ms << ","
+     << "\"inference_start_ms\":" << inference_start_ms << ","
+     << "\"inference_end_ms\":" << inference_end_ms << ","
+     << "\"prompt_eval_end_ms\":" << prompt_eval_end_ms << ","
+     << "\"first_token_ms\":" << first_token_ms << ","
+     << "\"aggregate_sampling_time_ms\":" << aggregate_sampling_time_ms << ","
+     << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
+     << SCALING_FACTOR_UNITS_PER_SECOND << "}";
+  return ss.str();
+}
+
 void Runner::stop() {
   shouldStop_ = true;
 }
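A note on util::time_in_ms(): the helper used throughout this diff is defined elsewhere in the ExecuTorch runner utilities and is not part of this change. As a minimal sketch only, assuming POSIX clock_gettime() is available, a millisecond timer of this shape could look like:

#include <time.h>

// Hypothetical stand-in for util::time_in_ms(); the real helper in the
// ExecuTorch sources may differ.
long time_in_ms() {
  struct timespec ts;
  clock_gettime(CLOCK_MONOTONIC, &ts); // monotonic clock avoids wall-clock jumps
  return ts.tv_sec * 1000 + ts.tv_nsec / 1000000;
}

Since every reported metric is a difference between two such timestamps, any clock source works as long as the same one is used consistently.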

examples/models/llama2/runner/runner.h

Lines changed: 30 additions & 0 deletions
@@ -63,6 +63,36 @@ class Runner {
   std::unique_ptr<Tokenizer> tokenizer_;
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
+
+  struct TimeStamps {
+    // Scaling factor for timestamps - in this case, we use ms.
+    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
+    // Time stamps for the different stages of the execution
+    // model_load_start_ms: Start of model loading.
+    long model_load_start_ms;
+    // model_load_end_ms: End of model loading.
+    long model_load_end_ms;
+    // inference_start_ms: Immediately after the model is loaded (or we check
+    // for model load), measure the inference time.
+    long inference_start_ms;
+    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
+    // before the inference loop starts
+    long prompt_eval_end_ms;
+    // first_token: Timestamp when the first generated token is emitted
+    long first_token_ms;
+    // inference_end_ms: End of inference/generation.
+    long inference_end_ms;
+    // Keep a running total of the time spent in sampling.
+    long aggregate_sampling_time_ms;
+
+    void printReport(
+        const int64_t& num_prompt_tokens,
+        const int64_t& num_generated_tokens);
+    const std::string toJsonString(
+        const int64_t& num_prompt_tokens,
+        const int64_t& num_generated_tokens);
+  };
+  TimeStamps timers_;
 };

 } // namespace torch::executor
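For reference, printReport() in runner.cpp above also emits one machine-readable line built by the TimeStamps::toJsonString() declared here, presumably so benchmarking or log-scraping tools can pick it up without parsing the tab-formatted report. With purely illustrative, made-up timestamp values it would look roughly like:

PyTorchObserver {"prompt_tokens":8,"generated_tokens":64,"model_load_start_ms":100,"model_load_end_ms":1300,"inference_start_ms":1300,"inference_end_ms":3450,"prompt_eval_end_ms":1450,"first_token_ms":1460,"aggregate_sampling_time_ms":42,"SCALING_FACTOR_UNITS_PER_SECOND":1000}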
