
Commit 5a628a9

kirklandsign authored and facebook-github-bot committed
Expose timestamp stats (#2794)
Summary:
Pull Request resolved: #2794
Reviewed By: shoumikhin
Differential Revision: D55604786
Pulled By: kirklandsign
1 parent 88b6cd2 commit 5a628a9
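
For callers, the practical change is that Runner::generate() now takes two callbacks: the existing per-token callback (renamed token_callback) and a new stats_callback that receives the populated Runner::Stats once generation finishes. A minimal caller sketch, not taken from this commit; the include path, file names, and Runner constructor arguments are assumptions modeled on the existing llama2 runner example:

// Hypothetical usage of the new stats_callback parameter. Paths, file names,
// and the Runner constructor arguments are assumptions, not part of this diff.
#include <cstdio>
#include <string>

#include "examples/models/llama2/runner/runner.h"

using torch::executor::Runner;

int main() {
  Runner runner("llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);

  const auto status = runner.generate(
      "Tell me a story",
      /*seq_len=*/128,
      // Streams each decoded piece as it is produced.
      [](const std::string& piece) { printf("%s", piece.c_str()); },
      // Receives the timing/token stats once generation completes.
      [](const Runner::Stats& stats) {
        double inference_time_ms =
            (double)(stats.inference_end_ms - stats.inference_start_ms);
        printf(
            "\n%lld tokens in %f s\n",
            (long long)stats.num_generated_tokens,
            inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND);
      });
  (void)status;
  return 0;
}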

2 files changed: +90, -80 lines


examples/models/llama2/runner/runner.cpp

Lines changed: 61 additions & 49 deletions
@@ -29,6 +29,8 @@
 namespace torch::executor {
 namespace {
 static constexpr auto kTopp = 0.9f;
+void printReport(const Runner::Stats& stats);
+std::string statsToJsonString(const Runner::Stats& stats);
 } // namespace

 Runner::Runner(
@@ -208,20 +210,21 @@ Result<torch::executor::Tensor> Runner::run_model_step(
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
-    std::function<void(const std::string&)> callback) {
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    timers_.model_load_start_ms = util::time_in_ms();
+    stats_.model_load_start_ms = util::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    timers_.model_load_end_ms = util::time_in_ms();
+    stats_.model_load_end_ms = util::time_in_ms();
   }

   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  timers_.inference_start_ms = util::time_in_ms();
+  stats_.inference_start_ms = util::time_in_ms();
   shouldStop_ = false;

   // encode the (string) prompt into tokens sequence
@@ -319,9 +322,9 @@ Error Runner::generate(
         run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);

     if (pos == num_prompt_tokens) {
-      timers_.first_token_ms = util::time_in_ms();
+      stats_.first_token_ms = util::time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
-      timers_.prompt_eval_end_ms = util::time_in_ms();
+      stats_.prompt_eval_end_ms = util::time_in_ms();
     }

     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
@@ -345,7 +348,7 @@ Error Runner::generate(
           "Unsupported dtype output %hhd",
           static_cast<int8_t>(logits_tensor.scalar_type()));
     }
-    timers_.aggregate_sampling_time_ms +=
+    stats_.aggregate_sampling_time_ms +=
         util::time_in_ms() - sample_start_time_ms;

     // advance the state machine
@@ -364,8 +367,8 @@ Error Runner::generate(
     util::safe_printf(piece);
     fflush(stdout);

-    if (callback) {
-      callback(piece);
+    if (token_callback) {
+      token_callback(piece);
     }

     if (shouldStop_) {
@@ -379,93 +382,102 @@ Error Runner::generate(
       break;
     }
   }
-  timers_.inference_end_ms = util::time_in_ms();
+  stats_.inference_end_ms = util::time_in_ms();
   printf("\n");

   if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }

-  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+  stats_.num_prompt_tokens = num_prompt_tokens;
+  stats_.num_generated_tokens = pos - num_prompt_tokens;
+  printReport(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }

   delete[] prompt_tokens;
   return Error::Ok;
 }

-void Runner::TimeStamps::printReport(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
-  printf(
-      "PyTorchObserver %s\n",
-      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+namespace {
+void printReport(const Runner::Stats& stats) {
+  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

   ET_LOG(
       Info,
       "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
-      num_prompt_tokens,
-      num_generated_tokens);
+      stats.num_prompt_tokens,
+      stats.num_generated_tokens);

   ET_LOG(
       Info,
       "\tModel Load Time:\t\t%f (seconds)",
-      ((double)(model_load_end_ms - model_load_start_ms) /
-       SCALING_FACTOR_UNITS_PER_SECOND));
-  double inference_time_ms = (double)(inference_end_ms - inference_start_ms);
+      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
+       stats.SCALING_FACTOR_UNITS_PER_SECOND));
+  double inference_time_ms =
+      (double)(stats.inference_end_ms - stats.inference_start_ms);
   ET_LOG(
       Info,
       "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      inference_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,

-      (num_generated_tokens) / (double)(inference_end_ms - inference_start_ms) *
-          SCALING_FACTOR_UNITS_PER_SECOND);
-  double prompt_eval_time = (double)(prompt_eval_end_ms - inference_start_ms);
+      (stats.num_generated_tokens) /
+          (double)(stats.inference_end_ms - stats.inference_start_ms) *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
+  double prompt_eval_time =
+      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
   ET_LOG(
       Info,
       "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      prompt_eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
-      (num_prompt_tokens) / prompt_eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
+      (stats.num_prompt_tokens) / prompt_eval_time *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);

-  double eval_time = (double)(inference_end_ms - prompt_eval_end_ms);
+  double eval_time =
+      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
   ET_LOG(
       Info,
       "\t\tGenerated %" PRIu64
       " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      num_generated_tokens,
-      eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
-      num_generated_tokens / eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+      stats.num_generated_tokens,
+      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
+      stats.num_generated_tokens / eval_time *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);

   // Time to first token is measured from the start of inference, excluding
   // model load time.
   ET_LOG(
       Info,
       "\tTime to first generated token:\t%f (seconds)",
-      ((double)(first_token_ms - inference_start_ms) /
-       SCALING_FACTOR_UNITS_PER_SECOND));
+      ((double)(stats.first_token_ms - stats.inference_start_ms) /
+       stats.SCALING_FACTOR_UNITS_PER_SECOND));

   ET_LOG(
       Info,
       "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
-      num_prompt_tokens + num_generated_tokens,
-      (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND);
+      stats.num_prompt_tokens + stats.num_generated_tokens,
+      (double)stats.aggregate_sampling_time_ms /
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
 }

-const std::string Runner::TimeStamps::toJsonString(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
+std::string statsToJsonString(const Runner::Stats& stats) {
   std::stringstream ss;
-  ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
-     << "\"generated_tokens\":" << num_generated_tokens << ","
-     << "\"model_load_start_ms\":" << model_load_start_ms << ","
-     << "\"model_load_end_ms\":" << model_load_end_ms << ","
-     << "\"inference_start_ms\":" << inference_start_ms << ","
-     << "\"inference_end_ms\":" << inference_end_ms << ","
-     << "\"prompt_eval_end_ms\":" << prompt_eval_end_ms << ","
-     << "\"first_token_ms\":" << first_token_ms << ","
-     << "\"aggregate_sampling_time_ms\":" << aggregate_sampling_time_ms << ","
+  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
+     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
+     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
+     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
+     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
+     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
+     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
+     << "\"first_token_ms\":" << stats.first_token_ms << ","
+     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
+     << ","
      << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
-     << SCALING_FACTOR_UNITS_PER_SECOND << "}";
+     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
   return ss.str();
 }
+} // namespace

 void Runner::stop() {
   shouldStop_ = true;
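
The machine-readable PyTorchObserver line that printReport() emits keeps its shape; it is now rendered from the Stats fields by statsToJsonString(). An illustrative example of the output, shown only to document the key order and the millisecond units (all numbers below are made up):

PyTorchObserver {"prompt_tokens":10,"generated_tokens":118,"model_load_start_ms":1000,"model_load_end_ms":2500,"inference_start_ms":2500,"inference_end_ms":7500,"prompt_eval_end_ms":2900,"first_token_ms":2950,"aggregate_sampling_time_ms":40,"SCALING_FACTOR_UNITS_PER_SECOND":1000}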

examples/models/llama2/runner/runner.h

Lines changed: 29 additions & 31 deletions
@@ -31,12 +31,39 @@ class Runner {
       const std::string& tokenizer_path,
       const float temperature = 0.8f);

+  struct Stats {
+    // Scaling factor for timestamps - in this case, we use ms.
+    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
+    // Time stamps for the different stages of the execution
+    // model_load_start_ms: Start of model loading.
+    long model_load_start_ms;
+    // model_load_end_ms: End of model loading.
+    long model_load_end_ms;
+    // inference_start_ms: Immediately after the model is loaded (or we check
+    // for model load), measure the inference time.
+    long inference_start_ms;
+    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
+    // before the inference loop starts
+    long prompt_eval_end_ms;
+    // first_token: Timestamp when the first generated token is emitted
+    long first_token_ms;
+    // inference_end_ms: End of inference/generation.
+    long inference_end_ms;
+    // Keep a running total of the time spent in sampling.
+    long aggregate_sampling_time_ms;
+    // Token count from prompt
+    int64_t num_prompt_tokens;
+    // Token count from generated (total - prompt)
+    int64_t num_generated_tokens;
+  };
+
   bool is_loaded() const;
   Error load();
   Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
-      std::function<void(const std::string&)> callback = {});
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const Stats&)> stats_callback = {});
   void stop();

 private:
@@ -68,36 +95,7 @@ class Runner {
   std::unique_ptr<Tokenizer> tokenizer_;
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
-
-  struct TimeStamps {
-    // Scaling factor for timestamps - in this case, we use ms.
-    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
-    // Time stamps for the different stages of the execution
-    // model_load_start_ms: Start of model loading.
-    long model_load_start_ms;
-    // model_load_end_ms: End of model loading.
-    long model_load_end_ms;
-    // inference_start_ms: Immediately after the model is loaded (or we check
-    // for model load), measure the inference time.
-    long inference_start_ms;
-    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
-    // before the inference loop starts
-    long prompt_eval_end_ms;
-    // first_token: Timestamp when the first generated token is emitted
-    long first_token_ms;
-    // inference_end_ms: End of inference/generation.
-    long inference_end_ms;
-    // Keep a running total of the time spent in sampling.
-    long aggregate_sampling_time_ms;
-
-    void printReport(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-    const std::string toJsonString(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-  };
-  TimeStamps timers_;
+  Stats stats_;
 };

 } // namespace torch::executor
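
Because Stats only carries raw millisecond timestamps and token counts, consumers that receive it through stats_callback derive the reported metrics themselves. A sketch of the arithmetic, mirroring what printReport() logs; the helper struct and function below are illustrative and not part of this commit:

// Illustrative helper (not in the diff): derives the metrics printReport()
// logs from a Runner::Stats. The include path is an assumption based on the
// file locations touched by this commit.
#include "examples/models/llama2/runner/runner.h"

struct DerivedMetrics {
  double model_load_s;      // seconds spent loading the model
  double time_to_first_s;   // first generated token relative to inference start
  double prompt_eval_tok_s; // prompt tokens per second
  double generation_tok_s;  // generated tokens per second
};

inline DerivedMetrics derive(const torch::executor::Runner::Stats& s) {
  const double ms_per_s = (double)s.SCALING_FACTOR_UNITS_PER_SECOND;
  DerivedMetrics m;
  m.model_load_s = (s.model_load_end_ms - s.model_load_start_ms) / ms_per_s;
  m.time_to_first_s = (s.first_token_ms - s.inference_start_ms) / ms_per_s;
  m.prompt_eval_tok_s = s.num_prompt_tokens /
      (double)(s.prompt_eval_end_ms - s.inference_start_ms) * ms_per_s;
  m.generation_tok_s = s.num_generated_tokens /
      (double)(s.inference_end_ms - s.prompt_eval_end_ms) * ms_per_s;
  return m;
}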
