 
 #include <ctime>
 #include <memory>
+#include <sstream>
 
 #ifdef USE_ATEN_LIB
 #include <torch/torch.h>
@@ -161,8 +162,16 @@ Error Runner::generate(
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
-  ET_CHECK_OK_OR_RETURN_ERROR(load());
+  if (!is_loaded()) {
+    timers_.model_load_start_ms = util::time_in_ms();
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+    timers_.model_load_end_ms = util::time_in_ms();
+  }
+
+  // First token time only measures the time it takes to encode the prompt and
+  // return a response token.
 
+  timers_.inference_start_ms = util::time_in_ms();
   shouldStop_ = false;
 
   // encode the (string) prompt into tokens sequence
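
The timestamps recorded above come from util::time_in_ms(), whose implementation is not part of this diff. As a rough sketch only (the helper name is taken from the calls in the patch; the actual ExecuTorch implementation may differ), a millisecond wall-clock helper of that shape could look like this:

// Hypothetical sketch of a util::time_in_ms()-style helper; not from the patch.
#include <chrono>

namespace util {
inline long time_in_ms() {
  // Milliseconds on a monotonic clock, suitable for measuring durations.
  return static_cast<long>(
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count());
}
} // namespace util
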
@@ -179,6 +188,7 @@ Error Runner::generate(
       append_eos_ ? n_eos_ : 0,
       prompt_tokens,
       &num_prompt_tokens);
+
   for (int i = 0; i < num_prompt_tokens; i++) {
     ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
   }
@@ -192,8 +202,6 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   // start the main loop
-  long start =
-      0; // used to time our code, only initialized after first iteration
   int next; // will store the next token in the sequence
   int64_t pos = num_prompt_tokens - 1; // position in the sequence
   int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
@@ -254,6 +262,7 @@ Error Runner::generate(
           tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
     }
   }
+
   // create a 1xN int tensor with next as value
   while (pos < seq_len) {
     // ET_LOG(Info, "Generating step %d...", pos);
@@ -289,10 +298,14 @@ Error Runner::generate(
         outputs.size() > 0,
         "Expecting output to have at least one evalue. Got %zu",
         outputs.size());
-
+    if (pos == num_prompt_tokens) {
+      timers_.first_token_ms = util::time_in_ms();
+    } else if (pos == num_prompt_tokens - 1) {
+      timers_.prompt_eval_end_ms = util::time_in_ms();
+    }
     int32_t next_tok;
     exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
-
+    long sample_start_time_ms = util::time_in_ms();
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
         next_tok = logitsToToken<float>(logits_tensor, pos, 0);
@@ -308,6 +321,8 @@ Error Runner::generate(
         "Unsupported dtype output %hhd",
         static_cast<int8_t>(logits_tensor.scalar_type()));
     }
+    timers_.aggregate_sampling_time_ms +=
+        util::time_in_ms() - sample_start_time_ms;
 
     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
@@ -339,16 +354,13 @@ Error Runner::generate(
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
     if (pos >= num_prompt_tokens && next == eos_id_) {
-      ET_LOG(Info, "Reached to the end of generation");
+      printf("\n");
+      ET_LOG(Info, "\nReached to the end of generation");
       break;
     }
 
     token = next;
 
-    // init the timer here because the first iteration can be slower
-    if (start == 0) {
-      start = util::time_in_ms();
-    }
     if (use_kv_cache_) {
       // outputs: [k_cache, v_cache, logits, k_cache, v_cache]
       memcpy(
@@ -361,23 +373,94 @@ Error Runner::generate(
           v_data.size());
     }
   }
+  timers_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
   if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
-  // report achieved tok/s (pos-1 because the timer starts after first
-  // iteration)
-  if (pos >= 1) {
-    long end = util::time_in_ms();
-    ET_LOG(
-        Info, "Achieved tok/s: %f\n", (pos - 1) / (double)(end - start) * 1000);
-  }
+
+  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
 
   delete[] prompt_tokens;
   return Error::Ok;
 }
 
+void Runner::TimeStamps::printReport(
+    const int64_t& num_prompt_tokens,
+    const int64_t& num_generated_tokens) {
+  ET_LOG(
+      Info,
+      "\tPrompt Tokens: %ld Generated Tokens: %ld",
+      num_prompt_tokens,
+      num_generated_tokens);
+
+  ET_LOG(
+      Info,
+      "\tModel Load Time:\t\t%f (seconds)",
+      ((double)(model_load_end_ms - model_load_start_ms) /
+       SCALING_FACTOR_UNITS_PER_SECOND));
+  double inference_time_ms = (double)(inference_end_ms - inference_start_ms);
+  ET_LOG(
+      Info,
+      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      inference_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+
+      (num_generated_tokens) / (double)(inference_end_ms - inference_start_ms) *
+          SCALING_FACTOR_UNITS_PER_SECOND);
+  double prompt_eval_time = (double)(prompt_eval_end_ms - inference_start_ms);
+  ET_LOG(
+      Info,
+      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      prompt_eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
+      (num_prompt_tokens) / prompt_eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+
+  double eval_time = (double)(inference_end_ms - prompt_eval_end_ms);
+  ET_LOG(
+      Info,
+      "\t\tGenerated %ld tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      num_generated_tokens,
+      eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
+      num_generated_tokens / eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+
+  // Time to first token is measured from the start of inference, excluding
+  // model load time.
+  ET_LOG(
+      Info,
+      "\tTime to first generated token:\t%f (seconds)",
+      ((double)(first_token_ms - inference_start_ms) /
+       SCALING_FACTOR_UNITS_PER_SECOND));
+
+  ET_LOG(
+      Info,
+      "\tSampling time over %ld tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+      num_prompt_tokens + num_generated_tokens,
+      (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+      (num_prompt_tokens + num_generated_tokens) / (double)aggregate_sampling_time_ms * SCALING_FACTOR_UNITS_PER_SECOND);
+
+  printf(
+      "[llama_runner_perf_data] %s\n",
+      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+}
+
+const std::string Runner::TimeStamps::toJsonString(
+    const int64_t& num_prompt_tokens,
+    const int64_t& num_generated_tokens) {
+  std::stringstream ss;
+  ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
+     << "\"generated_tokens\":" << num_generated_tokens << ","
+     << "\"model_load_start_ms\":" << model_load_start_ms << ","
+     << "\"model_load_end_ms\":" << model_load_end_ms << ","
+     << "\"inference_start_ms\":" << inference_start_ms << ","
+     << "\"inference_end_ms\":" << inference_end_ms << ","
+     << "\"prompt_eval_end_ms\":" << prompt_eval_end_ms << ","
+     << "\"first_token_ms\":" << first_token_ms << ","
+     << "\"aggregate_sampling_time_ms\":" << aggregate_sampling_time_ms << ","
+     << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
+     << SCALING_FACTOR_UNITS_PER_SECOND << "}";
+  return ss.str();
+}
+
 void Runner::stop() {
   shouldStop_ = true;
 }
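
For reference, the printReport and toJsonString definitions above imply a TimeStamps struct nested in Runner (declared in the header, which this excerpt does not show). A minimal sketch, with field names taken from the diff and the layout plus the 1000 ms-per-second scaling value assumed rather than confirmed:

// Hypothetical reconstruction of the timers_ member used above; not from the patch.
#include <cstdint>
#include <string>

struct TimeStamps {
  // Wall-clock timestamps in milliseconds, filled in by Runner::generate().
  long model_load_start_ms = 0;
  long model_load_end_ms = 0;
  long inference_start_ms = 0;
  long inference_end_ms = 0;
  long prompt_eval_end_ms = 0; // prompt prefill finished
  long first_token_ms = 0; // first generated token produced
  long aggregate_sampling_time_ms = 0; // summed over all sampling calls

  // Assumed conversion factor from the millisecond deltas above to seconds.
  static constexpr long SCALING_FACTOR_UNITS_PER_SECOND = 1000;

  void printReport(
      const int64_t& num_prompt_tokens,
      const int64_t& num_generated_tokens);
  const std::string toJsonString(
      const int64_t& num_prompt_tokens,
      const int64_t& num_generated_tokens);
};

In the patch itself the struct is a member of Runner (hence the Runner::TimeStamps:: qualifiers above), and the [llama_runner_perf_data] line gives downstream tooling a single JSON record containing the raw timestamps.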