@@ -519,8 +519,10 @@ def _main(

     tokenizer = _initialize_tokenizer(tokenizer_args)

-    # Right now the assumption is only llama3 uses tiktokenizer and it must use tiktokenizer.
-    # Piggy backing off of this flag then for now to identify llama3 without prompting user.
+    # Right now the assumption is only llama3 uses tiktokenizer and it
+    # must use tiktokenizer.
+    # Piggy backing off of this flag then for now to identify llama3
+    # without prompting user.
     is_llama3_model = tokenizer_args.is_tiktoken
     if generator_args.chat_mode and is_llama3_model:
         logging.debug(
@@ -610,7 +612,8 @@ def _main(
     start = -1 if generator_args.compile else 0
     start_pos = 0

-    # arbitrarily large number as chat mode goes until max_seq length or user exits
+    # arbitrarily large number as chat mode goes until max_seq length
+    # or user exits
     num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
     i = (
         -1
@@ -743,28 +746,33 @@ def callback(x):
         tokens_generated = y.size(0) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        logging.debug(
+        print(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
-        logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

         if start_pos >= max_seq_length:
-            print("Max Sequence Length Reached. Ending Conversation.")
-            break
+            print(f"[Max Sequence Length Reached. Ending Conversation.]")
+            print(f"---------------------------------------------------")
+            if generator_args.chat_mode:
+                break

-        print("==========")
+        if not generator_args.chat_mode:
+            start_pos = 0
+
+        print("\n========================================\n")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        logging.info(f"Acceptance probs: {acceptance_probs}")
-        logging.info(
+        print(f"Acceptance probs: {acceptance_probs}")
+        print(
             f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}"
         )

-        logging.info(
+        print(
             f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
         )
-        logging.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


 def main(args):