Skip to content

Commit e677fe8

Browse files
metascroy authored and malfet committed
stream results in generate.py (#571)
1 parent 9d10748 commit e677fe8

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

generate.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,9 +681,31 @@ def callback(
681681
# print(, end='', flush=True)
682682

683683
else:
684+
assert not generator_args.chat_mode
685+
buffer = [generator_args.prompt]
686+
period_id = tokenizer.encode(".")[0]
687+
done_generating = False
684688

685-
def callback(x):
686-
return x
689+
def callback(
    x, buffer=buffer, period_id=period_id, done_generating=done_generating
):
    """Stream one generated token to stdout, flushing in small batches.

    Each call decodes ``x`` and appends the text to ``buffer``. The buffer
    is printed and emptied once it holds four or more pieces, or as soon as
    an end-of-sequence token is seen. After an end token, every later call
    is a no-op.

    Args:
        x: The newly generated token. It is read with ``x.item()`` and
            ``x.tolist()`` (presumably a one-element tensor -- confirm
            against the caller).
        buffer: List of pending text pieces. It is shared between calls and
            mutated in place.
        period_id: Token id of ".", used as a decoding prefix.
        done_generating: When truthy, the call returns immediately.

    Returns:
        None. Output is written to stdout.
    """
    # Rebinding the `done_generating` parameter would be forgotten on the
    # next call, so the "finished" flag is kept on the function object. It
    # therefore lasts exactly as long as this particular `callback` object.
    if done_generating or getattr(callback, "_done_generating", False):
        return

    token_id = x.item()

    # Decode behind a leading "." and strip the first character of the
    # result (presumably so the tokenizer keeps the token's own leading
    # whitespace -- confirm against the tokenizer in use).
    # NOTE(review): the original author suspected this drops the first
    # output token from the display; verify.
    buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])

    finished = token_id == tokenizer.eos_id()
    if is_llama3_model and token_id == tokenizer.special_tokens["<|eot_id|>"]:
        finished = True
        # Drop the eot_id text in place so the shared buffer really loses it
        # (slicing would only rebind a local copy).
        buffer.pop()

    if finished:
        callback._done_generating = True

    # `>=` rather than `==` so an over-full buffer still gets flushed.
    if len(buffer) >= 4 or finished:
        print("".join(buffer), end="", flush=True)
        buffer.clear()
708+
# print(, end='', flush=True)
687709

688710
t0 = time.perf_counter()
689711
import contextlib

0 commit comments

Comments
 (0)