@@ -325,9 +325,8 @@ def main(
325
325
max_new_tokens : int = 100 ,
326
326
top_k : int = 200 ,
327
327
temperature : float = 0.8 ,
328
- checkpoint_path : Path = Path (
329
- "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
330
- ),
328
+ checkpoint_path : Optional [Path ] = None ,
329
+ tokenizer_path : Optional [Path ] = None ,
331
330
compile : bool = True ,
332
331
compile_prefill : bool = False ,
333
332
profile : Optional [Path ] = None ,
@@ -339,14 +338,21 @@ def main(
339
338
quantize = None ,
340
339
) -> None :
341
340
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
342
- assert checkpoint_path .is_file (), checkpoint_path
343
-
344
- torch .manual_seed (1234 )
345
-
346
- tokenizer_path = checkpoint_path .parent / "tokenizer.model"
341
+ assert (
342
+ (checkpoint_path and checkpoint_path .is_file ()) or
343
+ (dso_path and Path (dso_path ).is_file ()) or
344
+ (pte_path and Path (pte_path ).is_file ())
345
+ ), "need to specified a valid checkpoint path, DSO path, or PTE path"
346
+ assert not (dso_path and pte_path ), "specify either DSO path or PTE path, but not both"
347
+
348
+ if (checkpoint_path and (dso_path or pte_path )):
349
+ print ("Warning: checkpoint path ignored because an exported DSO or PTE path specified" )
350
+
351
+ if not tokenizer_path :
352
+ tokenizer_path = checkpoint_path .parent / "tokenizer.model"
347
353
assert tokenizer_path .is_file (), tokenizer_path
348
354
349
- global print
355
+ # global print
350
356
# from tp import maybe_init_dist
351
357
# rank = maybe_init_dist()
352
358
use_tp = False
@@ -540,10 +546,22 @@ def cli():
540
546
parser .add_argument (
541
547
"--temperature" , type = float , default = 0.8 , help = "Temperature for sampling."
542
548
)
549
+ parser .add_argument (
550
+ "--seed" ,
551
+ type = int ,
552
+ default = 1234 , # set None for release
553
+ help = "Initialize torch seed"
554
+ )
543
555
parser .add_argument (
544
556
"--checkpoint-path" ,
545
557
type = Path ,
546
- default = Path ("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" ),
558
+ default = None ,
559
+ help = "Model checkpoint path." ,
560
+ )
561
+ parser .add_argument (
562
+ "--tokenizer-path" ,
563
+ type = Path ,
564
+ default = None ,
547
565
help = "Model checkpoint path." ,
548
566
)
549
567
parser .add_argument (
@@ -590,6 +608,10 @@ def cli():
590
608
591
609
592
610
args = parser .parse_args ()
611
+
612
+ if args .seed :
613
+ torch .manual_seed (args .seed )
614
+
593
615
main (
594
616
args .prompt ,
595
617
args .interactive ,
@@ -598,6 +620,7 @@ def cli():
598
620
args .top_k ,
599
621
args .temperature ,
600
622
args .checkpoint_path ,
623
+ args .tokenizer_path ,
601
624
args .compile ,
602
625
args .compile_prefill ,
603
626
args .profile ,
0 commit comments