Skip to content

Commit d0588e2

Browse files
mikekgfb authored and malfet committed
Torchat tiktoken (#127)
* add torchat.py and --tiktoken option
* add default device to torchat
* dtype handling for export_et
* handle dtype args
1 parent 071f932 commit d0588e2

File tree

4 files changed

+177
-38
lines changed

4 files changed

+177
-38
lines changed

export.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ def forward(self, idx, input_pos):
6060
return logits # sample(logits, **sampling_kwargs)
6161

6262

63-
def main(checkpoint_path, device, quantize = "{ }", args = None):
63+
def main(args):
64+
checkpoint_path = args.checkpoint_path
65+
device = args.device
66+
quantize = args.quantize
67+
6468
assert checkpoint_path.is_file(), checkpoint_path
6569

6670
print(f"Using device={device}")
@@ -201,7 +205,7 @@ def cli():
201205

202206

203207
args = parser.parse_args()
204-
main(args.checkpoint_path, args.device, args.quantize, args)
208+
main(args)
205209

206210
if __name__ == "__main__":
207211
cli()

export_et.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from generate import _load_model, decode_one_token
1515
from quantize import quantize_model
16+
from quantize import quantize_model, name_to_dtype, set_precision, get_precision
1617

1718
from model import Transformer
1819
# from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
@@ -92,23 +93,23 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901
9293
# need to use kv sdpa?
9394
edge_config = EdgeCompileConfig(
9495
_check_ir_validity=False,
95-
_skip_type_promotion=bool(args.dtype == "fp16"),
96+
_skip_type_promotion=bool(target_precision == torch.float16),
9697
)
9798

9899
dynamic_shapes = None
99100

100-
if args.dtype is not None:
101-
if args.dtype == "fp16": # or args.quantization_mode == "int4":
102-
if state_dict_dtype != torch.float16:
103-
print("model.to torch.float16")
104-
model = model.to(dtype=torch.float16)
105-
state_dict_dtype = torch.float16
106-
elif args.dtype == "fp32":
107-
if state_dict_dtype != torch.float32:
108-
print("model.to torch.float32")
109-
model = model.to(dtype=torch.float32)
110-
else:
111-
raise ValueError(f"Unsupported dtype: {args.dtype}")
101+
# Coerce the model to the precision selected for ExecuTorch export.
# Only float16 and float32 are accepted; anything else is rejected below.
# NOTE(review): `target_precision` is also read earlier in this function,
# when EdgeCompileConfig is built, i.e. before this assignment runs --
# confirm the assignment is hoisted above that use.
target_precision = get_precision()
if target_precision == torch.float16:  # or args.quantization_mode=="int4":
    if state_dict_dtype != torch.float16:
        print("model.to torch.float16")
        model = model.to(dtype=torch.float16)
        state_dict_dtype = torch.float16
elif target_precision == torch.float32:  # comparison, not assignment (`=` was a SyntaxError)
    if state_dict_dtype != torch.float32:
        print("model.to torch.float32")
        model = model.to(dtype=torch.float32)
else:
    raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
112113

113114
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]), torch.no_grad():
114115
m = capture_pre_autograd_graph(

generate.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def _load_model(
336336
B_INST, E_INST = "[INST]", "[/INST]"
337337

338338

339-
def main(
339+
def _main(
340340
prompt: str = "Hello, my name is",
341341
interactive: bool = False,
342342
num_samples: int = 5,
@@ -357,6 +357,7 @@ def main(
357357
pte_path=None,
358358
quantize=None,
359359
model_dtype=None,
360+
use_tiktoken=False,
360361
) -> None:
361362
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
362363
assert (
@@ -573,6 +574,28 @@ def callback(x):
573574
)
574575
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
575576

577+
def main(args):
    """Run text generation from an already-parsed argument namespace.

    Reads the attributes listed in ``forwarded`` from ``args`` and passes
    their values to ``_main`` positionally, in exactly that order.

    NOTE(review): because the call is positional, this order must match
    ``_main``'s parameter order. The call site this wrapper replaced also
    passed ``args.checkpoint_dir`` and ``args.params_path`` right after
    ``args.checkpoint_path`` -- confirm ``_main`` no longer expects them.
    """
    forwarded = (
        "prompt",
        "interactive",
        "num_samples",
        "max_new_tokens",
        "top_k",
        "temperature",
        "checkpoint_path",
        "tokenizer_path",
        "compile",
        "compile_prefill",
        "profile",
        "draft_checkpoint_path",
        "speculate_k",
        "device",
        "dso_path",
        "pte_path",
        "quantize",
        "dtype",
        "tiktoken",
    )
    _main(*(getattr(args, name) for name in forwarded))
576599

577600
def cli():
578601
import argparse
@@ -672,35 +695,20 @@ def cli():
672695
default="float32",
673696
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
674697
)
698+
parser.add_argument(
699+
"--tiktoken",
700+
action="store_true",
701+
help="Whether to use tiktoken tokenizer.",
702+
)
675703

676704

677705
args = parser.parse_args()
678706

679707
if args.seed:
680708
torch.manual_seed(args.seed)
681709

682-
main(
683-
args.prompt,
684-
args.interactive,
685-
args.num_samples,
686-
args.max_new_tokens,
687-
args.top_k,
688-
args.temperature,
689-
args.checkpoint_path,
690-
args.checkpoint_dir,
691-
args.params_path,
692-
args.tokenizer_path,
693-
args.compile,
694-
args.compile_prefill,
695-
args.profile,
696-
args.draft_checkpoint_path,
697-
args.speculate_k,
698-
args.device,
699-
args.dso_path,
700-
args.pte_path,
701-
args.quantize,
702-
args.dtype,
703-
)
710+
main(args)
711+
704712

705713
if __name__ == "__main__":
706714
cli()

torchat.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import time
8+
import os
9+
from pathlib import Path
10+
11+
import torch
12+
import torch.nn as nn
13+
from torch.export import Dim, export
14+
15+
from export import main as export_main
16+
from generate import main as generate_main
17+
18+
default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'


def cli():
    """Parse the torchat command line and dispatch to generate or export.

    Exactly one mode flag is expected: ``--generate`` hands the parsed
    namespace to ``generate_main`` and ``--export`` hands it to
    ``export_main``. When both are given, ``--generate`` wins.

    Raises:
        RuntimeError: if neither ``--generate`` nor ``--export`` is given.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Your CLI description.")

    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Whether to launch in interactive mode",
    )
    parser.add_argument(
        "--tiktoken",
        action="store_true",
        help="Whether to use tiktoken tokenizer.",
    )
    parser.add_argument(
        "--export",
        action="store_true",
        help="Use torchat to export a model.",
    )
    parser.add_argument(
        "--generate",
        action="store_true",
        help="Use torchat to generate a sequence using a model.",
    )
    # The seed is read after parsing, so the option must be registered here;
    # without it every invocation failed with AttributeError on args.seed.
    parser.add_argument(
        "--seed", type=int, default=None, help="Random seed for reproducibility."
    )
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples.")
    parser.add_argument(
        "--max-new-tokens", type=int, default=200, help="Maximum number of new tokens."
    )
    parser.add_argument("--top-k", type=int, default=200, help="Top-k for sampling.")
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="Temperature for sampling."
    )
    parser.add_argument(
        "--compile", action="store_true", help="Whether to compile the model."
    )
    parser.add_argument(
        "--compile-prefill",
        action="store_true",
        help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
    )
    parser.add_argument(
        "--profile", type=Path, default=None, help="Profile path.")
    parser.add_argument(
        "--speculate-k", type=int, default=5, help="Speculative execution depth."
    )
    parser.add_argument(
        "--draft-checkpoint-path",
        type=Path,
        default=None,
        help="Draft checkpoint path.",
    )
    #####################################################################

    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        default="not_specified",
        help="Model checkpoint path.",
    )
    # generate_main reads args.tokenizer_path, args.dso_path and
    # args.pte_path, so they have to exist on the namespace.
    # NOTE(review): types and defaults mirror the other path options in this
    # parser -- confirm they agree with generate.py's own parser.
    parser.add_argument(
        "--tokenizer-path",
        type=Path,
        default=None,
        help="Tokenizer model path.",
    )
    parser.add_argument(
        "--dso-path",
        type=Path,
        default=None,
        help="Path to a DSO model to use for generation.",
    )
    parser.add_argument(
        "--pte-path",
        type=Path,
        default=None,
        help="Path to a PTE model to use for generation.",
    )
    parser.add_argument(
        "--output-pte-path",
        type=str,
        default=None,
        help="Filename"
    )
    parser.add_argument(
        "--output-dso-path",
        type=str,
        default=None,
        help="Filename"
    )
    parser.add_argument(
        "-d",
        "--dtype",
        default=None,
        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "--quantize",
        type=str,
        default="{ }",
        help="Quantization options."
    )
    parser.add_argument(
        "--device", type=str, default=default_device, help="Device to use"
    )

    args = parser.parse_args()

    # Compare against None so that an explicit seed of 0 is honoured too.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    if args.generate:
        generate_main(args)
    elif args.export:
        export_main(args)
    else:
        raise RuntimeError("must specify either --generate or --export")


if __name__ == "__main__":
    cli()

0 commit comments

Comments
 (0)