namespace torch::executor {
namespace {
static constexpr auto kTopp = 0.9f;
+void printReport(const Runner::Stats& stats);
+std::string statsToJsonString(const Runner::Stats& stats);
} // namespace

Runner::Runner(
@@ -208,20 +210,21 @@ Result<torch::executor::Tensor> Runner::run_model_step(
Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
-    std::function<void(const std::string&)> callback) {
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
  // Prepare the inputs.
  // Use ones-initialized inputs.
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
  if (!is_loaded()) {
-    timers_.model_load_start_ms = util::time_in_ms();
+    stats_.model_load_start_ms = util::time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
-    timers_.model_load_end_ms = util::time_in_ms();
+    stats_.model_load_end_ms = util::time_in_ms();
  }

  // First token time only measures the time it takes to encode the prompt and
  // return a response token.

-  timers_.inference_start_ms = util::time_in_ms();
+  stats_.inference_start_ms = util::time_in_ms();
  shouldStop_ = false;

  // encode the (string) prompt into tokens sequence
@@ -319,9 +322,9 @@ Error Runner::generate(
        run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);

    if (pos == num_prompt_tokens) {
-      timers_.first_token_ms = util::time_in_ms();
+      stats_.first_token_ms = util::time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
-      timers_.prompt_eval_end_ms = util::time_in_ms();
+      stats_.prompt_eval_end_ms = util::time_in_ms();
    }

    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
@@ -345,7 +348,7 @@ Error Runner::generate(
          "Unsupported dtype output %hhd",
          static_cast<int8_t>(logits_tensor.scalar_type()));
    }
-    timers_.aggregate_sampling_time_ms +=
+    stats_.aggregate_sampling_time_ms +=
        util::time_in_ms() - sample_start_time_ms;

    // advance the state machine
@@ -364,8 +367,8 @@ Error Runner::generate(
    util::safe_printf(piece);
    fflush(stdout);

-    if (callback) {
-      callback(piece);
+    if (token_callback) {
+      token_callback(piece);
    }

    if (shouldStop_) {
@@ -379,93 +382,102 @@ Error Runner::generate(
      break;
    }
  }
-  timers_.inference_end_ms = util::time_in_ms();
+  stats_.inference_end_ms = util::time_in_ms();
  printf("\n");

  if (pos == seq_len) {
    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
  }

-  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+  stats_.num_prompt_tokens = num_prompt_tokens;
+  stats_.num_generated_tokens = pos - num_prompt_tokens;
+  printReport(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }

  delete[] prompt_tokens;
  return Error::Ok;
}

-void Runner::TimeStamps::printReport(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
-  printf(
-      "PyTorchObserver %s\n",
-      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+namespace {
+void printReport(const Runner::Stats& stats) {
+  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

  ET_LOG(
      Info,
      "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
-      num_prompt_tokens,
-      num_generated_tokens);
+      stats.num_prompt_tokens,
+      stats.num_generated_tokens);

  ET_LOG(
      Info,
      "\tModel Load Time:\t\t%f (seconds)",
-      ((double)(model_load_end_ms - model_load_start_ms) /
-       SCALING_FACTOR_UNITS_PER_SECOND));
-  double inference_time_ms = (double)(inference_end_ms - inference_start_ms);
+      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
+       stats.SCALING_FACTOR_UNITS_PER_SECOND));
+  double inference_time_ms =
+      (double)(stats.inference_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      inference_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,

-      (num_generated_tokens) / (double)(inference_end_ms - inference_start_ms) *
-          SCALING_FACTOR_UNITS_PER_SECOND);
-  double prompt_eval_time = (double)(prompt_eval_end_ms - inference_start_ms);
+      (stats.num_generated_tokens) /
+          (double)(stats.inference_end_ms - stats.inference_start_ms) *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
+  double prompt_eval_time =
+      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      prompt_eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
-      (num_prompt_tokens) / prompt_eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
+      (stats.num_prompt_tokens) / prompt_eval_time *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);

-  double eval_time = (double)(inference_end_ms - prompt_eval_end_ms);
+  double eval_time =
+      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
  ET_LOG(
      Info,
      "\t\tGenerated %" PRIu64
      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      num_generated_tokens,
-      eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
-      num_generated_tokens / eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+      stats.num_generated_tokens,
+      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
+      stats.num_generated_tokens / eval_time *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  // Time to first token is measured from the start of inference, excluding
  // model load time.
  ET_LOG(
      Info,
      "\tTime to first generated token:\t%f (seconds)",
-      ((double)(first_token_ms - inference_start_ms) /
-       SCALING_FACTOR_UNITS_PER_SECOND));
+      ((double)(stats.first_token_ms - stats.inference_start_ms) /
+       stats.SCALING_FACTOR_UNITS_PER_SECOND));

  ET_LOG(
      Info,
      "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
-      num_prompt_tokens + num_generated_tokens,
-      (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND);
+      stats.num_prompt_tokens + stats.num_generated_tokens,
+      (double)stats.aggregate_sampling_time_ms /
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
}

-const std::string Runner::TimeStamps::toJsonString(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
+std::string statsToJsonString(const Runner::Stats& stats) {
  std::stringstream ss;
-  ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
-     << "\"generated_tokens\":" << num_generated_tokens << ","
-     << "\"model_load_start_ms\":" << model_load_start_ms << ","
-     << "\"model_load_end_ms\":" << model_load_end_ms << ","
-     << "\"inference_start_ms\":" << inference_start_ms << ","
-     << "\"inference_end_ms\":" << inference_end_ms << ","
-     << "\"prompt_eval_end_ms\":" << prompt_eval_end_ms << ","
-     << "\"first_token_ms\":" << first_token_ms << ","
-     << "\"aggregate_sampling_time_ms\":" << aggregate_sampling_time_ms << ","
+  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
+     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
+     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
+     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
+     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
+     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
+     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
+     << "\"first_token_ms\":" << stats.first_token_ms << ","
+     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
+     << ","
     << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
-     << SCALING_FACTOR_UNITS_PER_SECOND << "}";
+     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
  return ss.str();
}
+} // namespace

void Runner::stop() {
  shouldStop_ = true;
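
For context, the Stats fields this diff reads and writes suggest a struct in the runner header roughly like the sketch below. This is an inferred outline, not the actual declaration; the member types, the scaling constant's value, and the comments are assumptions based only on the usages above (PRIu64 formatting for the token counts, millisecond timestamps divided by SCALING_FACTOR_UNITS_PER_SECOND).

// Inferred sketch of Runner::Stats (not copied from the real header).
// All timestamps come from util::time_in_ms(), i.e. milliseconds.
struct Stats {
  // Divisor that converts the millisecond fields below into seconds.
  const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
  long model_load_start_ms;        // start of load()
  long model_load_end_ms;          // end of load()
  long inference_start_ms;         // start of prompt encoding + generation
  long prompt_eval_end_ms;         // prompt (prefill) fully consumed
  long first_token_ms;             // first generated token produced
  long inference_end_ms;           // generation finished
  long aggregate_sampling_time_ms; // total time spent sampling logits
  int64_t num_prompt_tokens;       // printed with PRIu64 above
  int64_t num_generated_tokens;    // printed with PRIu64 above
};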
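
For illustration, a caller-side sketch of the updated generate() signature with both callbacks wired up. The Runner constructor arguments (model path, tokenizer path, temperature) are placeholders for this example and are not part of this change.

// Placeholder setup; only the generate() call reflects this diff.
torch::executor::Runner runner(
    "/path/to/llama2.pte", "/path/to/tokenizer.bin", /*temperature=*/0.8f);

runner.generate(
    "Tell me a story.",
    /*seq_len=*/128,
    // token_callback: streams each decoded piece as it is produced.
    [](const std::string& piece) { printf("%s", piece.c_str()); },
    // stats_callback: receives the populated Stats once generation finishes.
    [](const torch::executor::Runner::Stats& stats) {
      ET_LOG(
          Info,
          "prompt tokens: %" PRIu64 ", generated tokens: %" PRIu64,
          stats.num_prompt_tokens,
          stats.num_generated_tokens);
    });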