
Commit d24818a

Expose timestamp stats (#2794)

Authored by kirklandsign, committed by facebook-github-bot

Summary: Pull Request resolved: #2794
Differential Revision: D55604786
Pulled By: kirklandsign

1 parent 399482c · commit d24818a

File tree: 2 files changed, +45 −44 lines

examples/models/llama2/runner/runner.cpp

Lines changed: 13 additions & 13 deletions
@@ -208,7 +208,8 @@ Result<torch::executor::Tensor> Runner::run_model_step(
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
-    std::function<void(const std::string&)> callback) {
+    std::function<void(const std::string&)> on_token_generated_callback,
+    std::function<void(const TimeStampsAndStats&)> on_stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -364,8 +365,8 @@ Error Runner::generate(
     util::safe_printf(piece);
     fflush(stdout);

-    if (callback) {
-      callback(piece);
+    if (on_token_generated_callback) {
+      on_token_generated_callback(piece);
     }

     if (shouldStop_) {
@@ -386,18 +387,19 @@ Error Runner::generate(
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }

-  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+  timers_.num_prompt_tokens = num_prompt_tokens;
+  timers_.num_generated_tokens = pos - num_prompt_tokens;
+  timers_.printReport();
+  if (on_stats_callback) {
+    on_stats_callback(timers_);
+  }

   delete[] prompt_tokens;
   return Error::Ok;
 }

-void Runner::TimeStamps::printReport(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
-  printf(
-      "PyTorchObserver %s\n",
-      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+void Runner::TimeStampsAndStats::printReport() {
+  printf("PyTorchObserver %s\n", toJsonString().c_str());

   ET_LOG(
       Info,
@@ -449,9 +451,7 @@ void Runner::TimeStamps::printReport(
       (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND);
 }

-const std::string Runner::TimeStamps::toJsonString(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
+const std::string Runner::TimeStampsAndStats::toJsonString() {
   std::stringstream ss;
   ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
      << "\"generated_tokens\":" << num_generated_tokens << ","
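With this change, generate() takes two optional callbacks: one fired per decoded token and one fired once with the collected timestamps/stats. Below is a minimal caller-side sketch (not part of this commit) of how the new overload might be invoked. The include path, the model/tokenizer file names, and the first constructor argument (model path) are assumptions based on the surrounding repo layout; the callback signatures and the torch::executor namespace come from this diff.

    // Hypothetical caller of the updated Runner::generate() API.
    #include <executorch/examples/models/llama2/runner/runner.h>  // assumed include path

    #include <cstdio>
    #include <string>

    using torch::executor::Runner;

    int main() {
      // Constructor arguments assumed: model path, tokenizer path, temperature.
      Runner runner("llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);

      runner.generate(
          "Hello, world",
          /*seq_len=*/128,
          // Invoked once per generated token with the decoded piece.
          [](const std::string& piece) { printf("%s", piece.c_str()); },
          // Invoked once after generation with the collected stats struct.
          [](const Runner::TimeStampsAndStats& stats) {
            printf(
                "\ngenerated %lld tokens\n",
                static_cast<long long>(stats.num_generated_tokens));
          });
      return 0;
    }

Because both callbacks default to {}, existing callers that only pass a prompt (or a prompt plus a token callback) continue to compile unchanged.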

examples/models/llama2/runner/runner.h

Lines changed: 32 additions & 31 deletions
@@ -31,12 +31,42 @@ class Runner {
       const std::string& tokenizer_path,
       const float temperature = 0.8f);

+  struct TimeStampsAndStats {
+    // Scaling factor for timestamps - in this case, we use ms.
+    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
+    // Time stamps for the different stages of the execution
+    // model_load_start_ms: Start of model loading.
+    long model_load_start_ms;
+    // model_load_end_ms: End of model loading.
+    long model_load_end_ms;
+    // inference_start_ms: Immediately after the model is loaded (or we check
+    // for model load), measure the inference time.
+    long inference_start_ms;
+    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
+    // before the inference loop starts
+    long prompt_eval_end_ms;
+    // first_token_ms: Timestamp when the first generated token is emitted
+    long first_token_ms;
+    // inference_end_ms: End of inference/generation.
+    long inference_end_ms;
+    // Keep a running total of the time spent in sampling.
+    long aggregate_sampling_time_ms;
+    // Token count from prompt
+    int64_t num_prompt_tokens;
+    // Token count from generated (total - prompt)
+    int64_t num_generated_tokens;
+
+    void printReport();
+    const std::string toJsonString();
+  };
+
   bool is_loaded() const;
   Error load();
   Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
-      std::function<void(const std::string&)> callback = {});
+      std::function<void(const std::string&)> on_token_generated_callback = {},
+      std::function<void(const TimeStampsAndStats&)> on_stats_callback = {});
   void stop();

  private:
@@ -68,36 +98,7 @@ class Runner {
   std::unique_ptr<Tokenizer> tokenizer_;
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
-
-  struct TimeStamps {
-    // Scaling factor for timestamps - in this case, we use ms.
-    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
-    // Time stamps for the different stages of the execution
-    // model_load_start_ms: Start of model loading.
-    long model_load_start_ms;
-    // model_load_end_ms: End of model loading.
-    long model_load_end_ms;
-    // inference_start_ms: Immediately after the model is loaded (or we check
-    // for model load), measure the inference time.
-    long inference_start_ms;
-    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
-    // before the inference loop starts
-    long prompt_eval_end_ms;
-    // first_token_ms: Timestamp when the first generated token is emitted
-    long first_token_ms;
-    // inference_end_ms: End of inference/generation.
-    long inference_end_ms;
-    // Keep a running total of the time spent in sampling.
-    long aggregate_sampling_time_ms;
-
-    void printReport(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-    const std::string toJsonString(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-  };
-  TimeStamps timers_;
+  TimeStampsAndStats timers_;
 };

 } // namespace torch::executor
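The header change moves the former private TimeStamps struct into the public API (renamed TimeStampsAndStats) and folds the token counts into it, so a stats callback receives everything it needs in one value. Below is an illustrative consumer that derives common metrics from the public fields; the function name, the specific derived metrics, and the exact interpretation of each interval are examples and assumptions, not part of this commit.

    // Sketch of a stats consumer for Runner::TimeStampsAndStats.
    #include <executorch/examples/models/llama2/runner/runner.h>  // assumed include path

    #include <cstdio>

    void log_generation_stats(const torch::executor::Runner::TimeStampsAndStats& s) {
      // SCALING_FACTOR_UNITS_PER_SECOND is 1000, i.e. timestamps are in ms.
      const double ms_per_s = static_cast<double>(s.SCALING_FACTOR_UNITS_PER_SECOND);

      // Derived intervals (assumed interpretation of the raw timestamps).
      const double model_load_s =
          (s.model_load_end_ms - s.model_load_start_ms) / ms_per_s;
      const double prompt_eval_s =
          (s.prompt_eval_end_ms - s.inference_start_ms) / ms_per_s;
      const double first_token_s =
          (s.first_token_ms - s.inference_start_ms) / ms_per_s;
      const double total_inference_s =
          (s.inference_end_ms - s.inference_start_ms) / ms_per_s;

      const double tokens_per_s = total_inference_s > 0
          ? static_cast<double>(s.num_prompt_tokens + s.num_generated_tokens) /
              total_inference_s
          : 0.0;

      printf("model load:        %.3f s\n", model_load_s);
      printf("prompt eval:       %.3f s\n", prompt_eval_s);
      printf("time to 1st token: %.3f s\n", first_token_s);
      printf("total inference:   %.3f s (%.2f tok/s)\n", total_inference_s, tokens_per_s);
      printf("sampling total:    %.3f s\n", s.aggregate_sampling_time_ms / ms_per_s);
    }

Such a function can be passed directly as the on_stats_callback argument of generate(), wrapped in a lambda if additional state is needed.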
