
Commit e15b509

mikekgfb authored and malfet committed
make num samples work for directed prompt-based sequence generation (#715)
1 parent b13ecb3 commit e15b509

File tree

1 file changed: +20 -12 lines changed

generate.py

Lines changed: 20 additions & 12 deletions
@@ -519,8 +519,10 @@ def _main(
 
     tokenizer = _initialize_tokenizer(tokenizer_args)
 
-    # Right now the assumption is only llama3 uses tiktokenizer and it must use tiktokenizer.
-    # Piggy backing off of this flag then for now to identify llama3 without prompting user.
+    # Right now the assumption is only llama3 uses tiktokenizer and it
+    # must use tiktokenizer.
+    # Piggy backing off of this flag then for now to identify llama3
+    # without prompting user.
     is_llama3_model = tokenizer_args.is_tiktoken
     if generator_args.chat_mode and is_llama3_model:
         logging.debug(
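
The comment rewrapped above describes a piggyback check. A minimal sketch of that pattern, modeling only the is_tiktoken field the hunk actually uses (the dataclass and detect_llama3 helper here are illustrative stand-ins, not code from generate.py):

    from dataclasses import dataclass

    @dataclass
    class TokenizerArgs:
        # Stand-in for the repository's tokenizer arguments; only the
        # is_tiktoken field touched by this hunk is modeled.
        is_tiktoken: bool = False

    def detect_llama3(tokenizer_args: TokenizerArgs) -> bool:
        # Only llama3 uses tiktoken today, so the tokenizer flag doubles
        # as a model-family check without prompting the user.
        return tokenizer_args.is_tiktoken

    print(detect_llama3(TokenizerArgs(is_tiktoken=True)))  # True
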
@@ -610,7 +612,8 @@ def _main(
     start = -1 if generator_args.compile else 0
     start_pos = 0
 
-    # arbitrarily large number as chat mode goes until max_seq length or user exits
+    # arbitrarily large number as chat mode goes until max_seq length
+    # or user exits
     num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
     i = (
         -1
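
This hunk's counter setup feeds the sampling loop. A minimal sketch of the shape that setup implies, with stand-in variables for the generator_args fields; the range-based loop and warmup handling are assumptions, not copied from generate.py:

    chat_mode = False      # stand-in for generator_args.chat_mode
    compile_model = True   # stand-in for generator_args.compile
    requested_samples = 3  # stand-in for generator_args.num_samples

    # Chat mode runs until max_seq_length or the user exits, so it gets
    # an arbitrarily large sample budget instead of the requested count.
    num_samples = requested_samples if not chat_mode else 100000

    # When compiling, start counting at -1 so the first pass serves as a
    # warmup whose timings stay out of the aggregate metrics.
    start = -1 if compile_model else 0
    for i in range(start, num_samples):
        if i < 0:
            print("warmup pass (not measured)")
        else:
            print(f"sample {i + 1} of {num_samples}")
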
@@ -743,28 +746,33 @@ def callback(x):
         tokens_generated = y.size(0) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        logging.debug(
+        print(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
-        logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
 
         if start_pos >= max_seq_length:
-            print("Max Sequence Length Reached. Ending Conversation.")
-            break
+            print(f"[Max Sequence Length Reached. Ending Conversation.]")
+            print(f"---------------------------------------------------")
+            if generator_args.chat_mode:
+                break
 
-        print("==========")
+        if not generator_args.chat_mode:
+            start_pos = 0
+
+        print("\n========================================\n")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        logging.info(f"Acceptance probs: {acceptance_probs}")
-        logging.info(
+        print(f"Acceptance probs: {acceptance_probs}")
+        print(
             f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}"
         )
 
-        logging.info(
+        print(
             f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
         )
-        logging.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
 
 
 def main(args):
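
Taken together, the hunk above carries the substance of the commit: reaching max_seq_length no longer unconditionally breaks out of the loop, and non-chat runs reset start_pos between samples so each directed prompt starts generating from position 0. A minimal sketch of that control flow, with generate_one_sample as a hypothetical stand-in for the real generation call:

    def generate_one_sample(start_pos: int) -> int:
        # Hypothetical stand-in for the real generation call: pretend
        # each sample emits 100 new tokens starting at start_pos.
        return 100

    def run_samples(num_samples: int, max_seq_length: int, chat_mode: bool) -> None:
        start_pos = 0
        for _ in range(num_samples):
            start_pos += generate_one_sample(start_pos)
            if start_pos >= max_seq_length:
                print("[Max Sequence Length Reached. Ending Conversation.]")
                if chat_mode:
                    break  # only a chat session ends the run here now
            if not chat_mode:
                start_pos = 0  # the fix: the next sample starts fresh

    run_samples(num_samples=3, max_seq_length=250, chat_mode=False)

Before this change, any sample that overran max_seq_length ended the whole run, so only the first of several prompt-directed samples could complete; with the per-sample reset, all requested samples run to completion in non-chat mode.
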
