File tree Expand file tree Collapse file tree 1 file changed +24
-2
lines changed Expand file tree Collapse file tree 1 file changed +24
-2
lines changed Original file line number Diff line number Diff line change @@ -681,9 +681,31 @@ def callback(
681
681
# print(, end='', flush=True)
682
682
683
683
else :
684
+ assert not generator_args .chat_mode
685
+ buffer = [generator_args .prompt ]
686
+ period_id = tokenizer .encode ("." )[0 ]
687
+ done_generating = False
684
688
685
- def callback (x ):
686
- return x
689
+ def callback (
690
+ x , buffer = buffer , period_id = period_id , done_generating = done_generating
691
+ ):
692
+ if done_generating :
693
+ return
694
+ buffer .append (
695
+ tokenizer .decode ([period_id ] + x .tolist ())[1 :]
696
+ ) # I think this results in the first output token being dropped from the display which is wrong.
697
+ if x .item () == tokenizer .eos_id ():
698
+ done_generating = True
699
+ if (
700
+ is_llama3_model
701
+ and x .item () == tokenizer .special_tokens ["<|eot_id|>" ]
702
+ ):
703
+ done_generating = True
704
+ buffer = buffer [:- 1 ] # drop the eot_id from the output buffer
705
+ if len (buffer ) == 4 or done_generating :
706
+ print ("" .join (buffer ), end = "" , flush = True )
707
+ buffer .clear ()
708
+ # print(, end='', flush=True)
687
709
688
710
t0 = time .perf_counter ()
689
711
import contextlib
You can’t perform that action at this time.
0 commit comments