Skip to content

Commit 9745fe0

Browse files
Michael Gschwindmalfet
authored andcommitted
fixes issue #36
1 parent 15709e0 commit 9745fe0

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

generate.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,8 @@ def main(
325325
max_new_tokens: int = 100,
326326
top_k: int = 200,
327327
temperature: float = 0.8,
328-
checkpoint_path: Path = Path(
329-
"checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
330-
),
328+
checkpoint_path: Optional[Path] = None,
329+
tokenizer_path: Optional[Path] = None,
331330
compile: bool = True,
332331
compile_prefill: bool = False,
333332
profile: Optional[Path] = None,
@@ -339,14 +338,21 @@ def main(
339338
quantize=None,
340339
) -> None:
341340
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
342-
assert checkpoint_path.is_file(), checkpoint_path
343-
344-
torch.manual_seed(1234)
345-
346-
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
341+
assert (
342+
(checkpoint_path and checkpoint_path.is_file()) or
343+
(dso_path and Path(dso_path).is_file()) or
344+
(pte_path and Path(pte_path).is_file())
345+
), "need to specified a valid checkpoint path, DSO path, or PTE path"
346+
assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both"
347+
348+
if (checkpoint_path and (dso_path or pte_path)):
349+
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
350+
351+
if not tokenizer_path:
352+
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
347353
assert tokenizer_path.is_file(), tokenizer_path
348354

349-
global print
355+
# global print
350356
# from tp import maybe_init_dist
351357
# rank = maybe_init_dist()
352358
use_tp = False
@@ -540,10 +546,22 @@ def cli():
540546
parser.add_argument(
541547
"--temperature", type=float, default=0.8, help="Temperature for sampling."
542548
)
549+
parser.add_argument(
550+
"--seed",
551+
type=int,
552+
default=1234, # set None for release
553+
help="Initialize torch seed"
554+
)
543555
parser.add_argument(
544556
"--checkpoint-path",
545557
type=Path,
546-
default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
558+
default=None,
559+
help="Model checkpoint path.",
560+
)
561+
parser.add_argument(
562+
"--tokenizer-path",
563+
type=Path,
564+
default=None,
547565
help="Model checkpoint path.",
548566
)
549567
parser.add_argument(
@@ -590,6 +608,10 @@ def cli():
590608

591609

592610
args = parser.parse_args()
611+
612+
if args.seed:
613+
torch.manual_seed(args.seed)
614+
593615
main(
594616
args.prompt,
595617
args.interactive,
@@ -598,6 +620,7 @@ def cli():
598620
args.top_k,
599621
args.temperature,
600622
args.checkpoint_path,
623+
args.tokenizer_path,
601624
args.compile,
602625
args.compile_prefill,
603626
args.profile,

0 commit comments

Comments
 (0)