Commit dff3368

Use a callback to report TimeStampsAndStats to users
We keep track of num_prompt_tokens and num_generated_tokens along with the timers, then report everything back to the user.
1 parent 5efc1c4 commit dff3368
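
For illustration, a minimal sketch of the new call site, assuming a standalone main.cpp. The include path, the artifact paths, and the constructor's first (model path) argument are placeholders not shown in this diff; the rest follows the signatures below:

```cpp
#include <cinttypes>
#include <cstdio>
#include <string>

#include <executorch/examples/models/llama2/runner/runner.h> // assumed path

int main() {
  // Placeholder artifact paths; the temperature matches the 0.8f default
  // in runner.h.
  torch::executor::Runner runner("llama2.pte", "tokenizer.bin", 0.8f);

  runner.generate(
      "Once upon a time,",
      /*seq_len=*/128,
      // Fires once per decoded token piece, as before.
      [](const std::string& piece) { std::printf("%s", piece.c_str()); },
      // New: fires once at the end of generation with the collected stats.
      [](const torch::executor::Runner::TimeStampsAndStats& stats) {
        std::printf(
            "\nprompt tokens: %" PRId64 ", generated tokens: %" PRId64 "\n",
            stats.num_prompt_tokens,
            stats.num_generated_tokens);
      });
  return 0;
}
```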

2 files changed: +30 -27 lines
examples/models/llama2/runner/runner.cpp

Lines changed: 13 additions & 11 deletions
```diff
@@ -208,7 +208,8 @@ Result<torch::executor::Tensor> Runner::run_model_step(
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
-    std::function<void(const std::string&)> callback) {
+    std::function<void(const std::string&)> on_token_generated_callback,
+    std::function<void(const TimeStampsAndStats&)> on_stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -364,8 +365,8 @@ Error Runner::generate(
     util::safe_printf(piece);
     fflush(stdout);
 
-    if (callback) {
-      callback(piece);
+    if (on_token_generated_callback) {
+      on_token_generated_callback(piece);
     }
 
     if (shouldStop_) {
@@ -386,18 +387,21 @@ Error Runner::generate(
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
-  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+  timers_.num_prompt_tokens = num_prompt_tokens;
+  timers_.num_generated_tokens = pos - num_prompt_tokens;
+  timers_.printReport();
+  if (on_stats_callback) {
+    on_stats_callback(timers_);
+  }
 
   delete[] prompt_tokens;
   return Error::Ok;
 }
 
-void Runner::TimeStamps::printReport(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
+void Runner::TimeStampsAndStats::printReport() {
   printf(
       "PyTorchObserver %s\n",
-      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+      toJsonString().c_str());
 
   ET_LOG(
       Info,
@@ -449,9 +453,7 @@ void Runner::TimeStamps::printReport(
       (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND);
 }
 
-const std::string Runner::TimeStamps::toJsonString(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
+const std::string Runner::TimeStampsAndStats::toJsonString() {
   std::stringstream ss;
   ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
      << "\"generated_tokens\":" << num_generated_tokens << ","
```

examples/models/llama2/runner/runner.h

Lines changed: 17 additions & 16 deletions
```diff
@@ -31,15 +31,7 @@ class Runner {
       const std::string& tokenizer_path,
       const float temperature = 0.8f);
 
-  bool is_loaded() const;
-  Error load();
-  Error generate(
-      const std::string& prompt,
-      int32_t seq_len = 128,
-      std::function<void(const std::string&)> callback = {});
-  void stop();
-
-  struct TimeStamps {
+  struct TimeStampsAndStats {
     // Scaling factor for timestamps - in this case, we use ms.
     const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
     // Time stamps for the different stages of the execution
@@ -59,15 +51,23 @@ class Runner {
     long inference_end_ms;
     // Keep a running total of the time spent in sampling.
     long aggregate_sampling_time_ms;
+    // Token count from prompt
+    int64_t num_prompt_tokens;
+    // Token count from generated (total - prompt)
+    int64_t num_generated_tokens;
 
-    void printReport(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-    const std::string toJsonString(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
+    void printReport();
+    const std::string toJsonString();
   };
-  TimeStamps timers_;
+
+  bool is_loaded() const;
+  Error load();
+  Error generate(
+      const std::string& prompt,
+      int32_t seq_len = 128,
+      std::function<void(const std::string&)> on_token_generated_callback = {},
+      std::function<void(const TimeStampsAndStats&)> on_stats_callback = {});
+  void stop();
 
  private:
   // metadata
@@ -98,6 +98,7 @@ class Runner {
   std::unique_ptr<Tokenizer> tokenizer_;
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
+  TimeStampsAndStats timers_;
 };
 
 } // namespace torch::executor
```
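
As a usage note, callers can now consume the stats directly instead of scraping the printed report. A minimal sketch of an on_stats_callback, using only fields visible in this header (the log_stats helper and the include path are hypothetical, not part of the commit):

```cpp
#include <cstdio>

#include <executorch/examples/models/llama2/runner/runner.h> // assumed path

// Hypothetical helper suitable for passing as on_stats_callback.
void log_stats(const torch::executor::Runner::TimeStampsAndStats& stats) {
  // SCALING_FACTOR_UNITS_PER_SECOND is 1000, so the millisecond
  // counters divide down to seconds.
  double sampling_sec = (double)stats.aggregate_sampling_time_ms /
      stats.SCALING_FACTOR_UNITS_PER_SECOND;
  std::printf(
      "prompt=%lld generated=%lld sampling=%.3fs\n",
      (long long)stats.num_prompt_tokens,
      (long long)stats.num_generated_tokens,
      sampling_sec);
}
```

Since both callback parameters default to {}, this can be passed on its own, e.g. runner.generate(prompt, 128, {}, log_stats).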
