@@ -161,8 +161,12 @@ Error Runner::generate(
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
+  timers_.model_load = util::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(load());
+  timers_.model_load = util::time_in_ms() - timers_.model_load;
 
+  // First token time only measures the time it takes to encode the prompt and return a response token.
+  timers_.start = util::time_in_ms();
   shouldStop_ = false;
 
   // encode the (string) prompt into tokens sequence
@@ -173,12 +177,14 @@ Error Runner::generate(
   // Set the sequence length to the max seq length if not provided
   seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
 
+
   tokenizer_->encode(
       prompt.c_str(),
       n_bos_,
       append_eos_ ? n_eos_ : 0,
       prompt_tokens,
       &num_prompt_tokens);
+
   for (int i = 0; i < num_prompt_tokens; i++) {
     ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
   }
@@ -192,8 +198,6 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   // start the main loop
-  long start =
-      0; // used to time our code, only initialized after first iteration
   int next; // will store the next token in the sequence
   int64_t pos = num_prompt_tokens - 1; // position in the sequence
   int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
@@ -255,6 +259,7 @@ Error Runner::generate(
           tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
     }
   }
+
   // create a 1xN int tensor with next as value
   while (pos < seq_len) {
     // ET_LOG(Info, "Generating step %d...", pos);
@@ -290,7 +295,12 @@ Error Runner::generate(
         outputs.size() > 0,
         "Expecting output to have at least one evalue. Got %zu",
         outputs.size());
-
+    if (pos == num_prompt_tokens) {
+      timers_.first_token = util::time_in_ms() - timers_.start;
+      timers_.remaining_tokens = util::time_in_ms();
+    } else if (pos == num_prompt_tokens - 1) {
+      timers_.prompt_eval = util::time_in_ms() - timers_.start;
+    }
     int32_t next_tok;
     exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
 
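Note on the hunk above: both branches run right after a forward pass returns. At pos == num_prompt_tokens - 1 the pass that consumed the last prompt token has just finished, which closes the prompt_eval window; one iteration later, at pos == num_prompt_tokens, the first response token has been produced, which closes the first_token window and opens the remaining_tokens window.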
@@ -342,6 +352,7 @@ Error Runner::generate(
     if (pos >= num_prompt_tokens && next == eos_id_) {
       eos_counter++;
       if (eos_counter == n_eos_) {
+        printf("\n");
         ET_LOG(Info, "Reached to the end of generation");
         break;
       }
@@ -351,10 +362,6 @@ Error Runner::generate(
 
     token = next;
 
-    // init the timer here because the first iteration can be slower
-    if (start == 0) {
-      start = util::time_in_ms();
-    }
     if (use_kv_cache_) {
       // outputs: [k_cache, v_cache, logits, k_cache, v_cache]
       memcpy(
@@ -367,23 +374,51 @@ Error Runner::generate(
           v_data.size());
     }
   }
+  timers_.remaining_tokens = util::time_in_ms() - timers_.remaining_tokens;
+  timers_.end = util::time_in_ms();
   printf("\n");
 
   if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
-  // report achieved tok/s (pos-1 because the timer starts after first
-  // iteration)
-  if (pos >= 1) {
-    long end = util::time_in_ms();
-    ET_LOG(
-        Info, "Achieved tok/s: %f\n", (pos - 1) / (double)(end - start) * 1000);
-  }
+
+  printReport(num_prompt_tokens, pos - num_prompt_tokens);
 
   delete[] prompt_tokens;
   return Error::Ok;
 }
 
+void Runner::printReport(
+    int64_t num_prompt_tokens,
+    int64_t num_generated_tokens) {
+  printf("\n");
+  double net_eval_time = (double)(
+      timers_.first_token + timers_.remaining_tokens - timers_.prompt_eval);
+  ET_LOG(
+      Info,
+      "\tPrompt Tokens: %ld Generated Tokens: %ld",
+      num_prompt_tokens,
+      num_generated_tokens);
+  ET_LOG(
+      Info,
+      "\tModel Load Time:\t\t%f (seconds)",
+      ((double)(timers_.model_load) / 1000));
+  ET_LOG(
+      Info,
+      "\tTotal inference time:\t\t%f (seconds)\t\tToken Rate: \t%f (tokens/second)",
+      (double)(timers_.end - timers_.start) / 1000,
+      num_generated_tokens / (double)(timers_.end - timers_.start) * 1000);
+  ET_LOG(
+      Info,
+      "\t\tTime to first token:\t%f (seconds)",
+      ((double)(timers_.first_token) / 1000));
+  ET_LOG(
+      Info,
+      "\t\t\tPrompt eval:\t%f (seconds)\t\tToken Rate: \t%f (tokens/second)",
+      ((double)(timers_.prompt_eval) / 1000),
+      num_prompt_tokens / (double)(timers_.prompt_eval) * 1000);
+  ET_LOG(
+      Info,
+      "\t\tRemaining %ld tokens:\t%f (seconds)\t\tToken Rate: \t%f (tokens/second)",
+      num_generated_tokens - 1,
+      (double)(timers_.remaining_tokens) / 1000,
+      (num_generated_tokens - 1) / (double)(timers_.remaining_tokens) * 1000);
+  ET_LOG(
+      Info,
+      "\t\tNet evaluation time:\t%f (seconds)\t\tToken Rate: \t%f (tokens/second)",
+      (net_eval_time / 1000),
+      num_generated_tokens / net_eval_time * 1000);
+}
+
 void Runner::stop() {
   shouldStop_ = true;
 }
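
The hunks above use a timers_ member and util::time_in_ms(), both declared elsewhere in the runner and not part of this diff. As a rough sketch only -- the struct name and comments are assumptions; the field names and their meanings are taken from the usages above -- the state backing printReport() could look like:

    // Hypothetical sketch; the real declaration lives elsewhere in the runner.
    struct Timers {
      long model_load;       // ms spent inside load()
      long start;            // wall-clock ms when generation timing begins
      long first_token;      // ms from start until the first response token exists
      long prompt_eval;      // ms from start until the last prompt token is consumed
      long remaining_tokens; // ms spent producing every token after the first
      long end;              // wall-clock ms when the generation loop exits
    };
    Timers timers_; // member of Runner, updated throughout generate()

printReport() then derives the net evaluation time as first_token + remaining_tokens - prompt_eval, i.e. time spent generating tokens with prompt evaluation excluded, and converts each window to a rate by dividing a token count by its duration in milliseconds and multiplying by 1000.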