@@ -208,7 +208,8 @@ Result<torch::executor::Tensor> Runner::run_model_step(
208
208
Error Runner::generate (
209
209
const std::string& prompt,
210
210
int32_t seq_len,
211
- std::function<void (const std::string&)> callback) {
211
+ std::function<void (const std::string&)> on_token_generated_callback,
212
+ std::function<void(const TimeStampsAndStats&)> on_stats_callback) {
212
213
// Prepare the inputs.
213
214
// Use ones-initialized inputs.
214
215
ET_CHECK_MSG (!prompt.empty (), " Prompt cannot be null" );
@@ -364,8 +365,8 @@ Error Runner::generate(
364
365
util::safe_printf (piece);
365
366
fflush (stdout);
366
367
367
- if (callback ) {
368
- callback (piece);
368
+ if (on_token_generated_callback ) {
369
+ on_token_generated_callback (piece);
369
370
}
370
371
371
372
if (shouldStop_) {
@@ -386,18 +387,21 @@ Error Runner::generate(
386
387
ET_LOG (Info, " Sequence length (%i tokens) reached!" , seq_len);
387
388
}
388
389
389
- timers_.printReport (num_prompt_tokens, pos - num_prompt_tokens);
390
+ timers_.num_prompt_tokens = num_prompt_tokens;
391
+ timers_.num_generated_tokens = pos - num_prompt_tokens;
392
+ timers_.printReport ();
393
+ if (on_stats_callback) {
394
+ on_stats_callback (timers_);
395
+ }
390
396
391
397
delete[] prompt_tokens;
392
398
return Error::Ok;
393
399
}
394
400
395
- void Runner::TimeStamps::printReport (
396
- const int64_t & num_prompt_tokens,
397
- const int64_t & num_generated_tokens) {
401
+ void Runner::TimeStampsAndStats::printReport () {
398
402
printf (
399
403
" PyTorchObserver %s\n " ,
400
- toJsonString (num_prompt_tokens, num_generated_tokens ).c_str ());
404
+ toJsonString ().c_str ());
401
405
402
406
ET_LOG (
403
407
Info,
@@ -449,9 +453,7 @@ void Runner::TimeStamps::printReport(
449
453
(double )aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND);
450
454
}
451
455
452
- const std::string Runner::TimeStamps::toJsonString (
453
- const int64_t & num_prompt_tokens,
454
- const int64_t & num_generated_tokens) {
456
+ const std::string Runner::TimeStampsAndStats::toJsonString () {
455
457
std::stringstream ss;
456
458
ss << " {\" prompt_tokens\" :" << num_prompt_tokens << " ,"
457
459
<< " \" generated_tokens\" :" << num_generated_tokens << " ,"
0 commit comments