@@ -463,6 +463,24 @@ def get_device_info(name: str) -> str:
        return torch.cuda.get_device_name(0)
    return ""
+def _callback(x, buffer, period_id, done_generating, tokenizer, is_llama3_model):
+    if done_generating:
+        return
+    buffer.append(
+        tokenizer.decode([period_id] + x.tolist())[1:]
+    )  # I think this results in the first output token being dropped from the display which is wrong.
+    if x.item() == tokenizer.eos_id():
+        done_generating = True
+    if (
+        is_llama3_model
+        and x.item() == tokenizer.special_tokens["<|eot_id|>"]
+    ):
+        done_generating = True
+        buffer = buffer[:-1]  # drop the eot_id from the output buffer
+    if len(buffer) == 4 or done_generating:
+        print("".join(buffer), end="", flush=True)
+        buffer.clear()
+    # print(, end='', flush=True)

def _main(
    builder_args: BuilderArgs,
@@ -612,7 +630,7 @@ def _main(
                break
            if not is_llama3_model:
                if system_prompt:
-                    prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
+                    prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
                    system_prompt = (
                        None  # can only provide system prompt on first interaction
                    )
@@ -659,53 +677,17 @@ def _main(
            period_id = tokenizer.encode(".")[0]
            done_generating = False

-            def callback(
-                x, buffer=buffer, period_id=period_id, done_generating=done_generating
-            ):
-                if done_generating:
-                    return
-                buffer.append(
-                    tokenizer.decode([period_id] + x.tolist())[1:]
-                )  # I think this results in the first output token being dropped from the display which is wrong.
-                if x.item() == tokenizer.eos_id():
-                    done_generating = True
-                if (
-                    is_llama3_model
-                    and x.item() == tokenizer.special_tokens["<|eot_id|>"]
-                ):
-                    done_generating = True
-                    buffer = buffer[:-1]  # drop the eot_id from the output buffer
-                if len(buffer) == 4 or done_generating:
-                    print("".join(buffer), end="", flush=True)
-                    buffer.clear()
-                # print(, end='', flush=True)
+            def callback(x):
+                return _callback(x, buffer=buffer, period_id=period_id, done_generating=done_generating, tokenizer=tokenizer, is_llama3_model=is_llama3_model)

        else:
            assert not generator_args.chat_mode
            buffer = [generator_args.prompt]
            period_id = tokenizer.encode(".")[0]
            done_generating = False

-            def callback(
-                x, buffer=buffer, period_id=period_id, done_generating=done_generating
-            ):
-                if done_generating:
-                    return
-                buffer.append(
-                    tokenizer.decode([period_id] + x.tolist())[1:]
-                )  # I think this results in the first output token being dropped from the display which is wrong.
-                if x.item() == tokenizer.eos_id():
-                    done_generating = True
-                if (
-                    is_llama3_model
-                    and x.item() == tokenizer.special_tokens["<|eot_id|>"]
-                ):
-                    done_generating = True
-                    buffer = buffer[:-1]  # drop the eot_id from the output buffer
-                if len(buffer) == 4 or done_generating:
-                    print("".join(buffer), end="", flush=True)
-                    buffer.clear()
-                # print(, end='', flush=True)
+            def callback(x):
+                return _callback(x, buffer=buffer, period_id=period_id, done_generating=done_generating, tokenizer=tokenizer, is_llama3_model=is_llama3_model)

        t0 = time.perf_counter()
        import contextlib
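
Aside (not part of the diff): both branches now only bind extra arguments onto the shared _callback, so an equivalent sketch of the same thin wrapper can use functools.partial instead of a local def. The names below (buffer, period_id, done_generating, tokenizer, is_llama3_model) are assumed to be in scope exactly as in the hunks above.

import functools

# Minimal sketch: bind the shared _callback the same way the
# `def callback(x): return _callback(x, ...)` wrappers in this diff do.
callback = functools.partial(
    _callback,
    buffer=buffer,
    period_id=period_id,
    done_generating=done_generating,
    tokenizer=tokenizer,
    is_llama3_model=is_llama3_model,
)

# callback(x) now forwards to _callback(x, buffer=..., period_id=..., ...),
# matching the behavior of the wrappers added in this change.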