
Commit 055cf7b

Merge branch 'main' into main
2 parents: 8a9e672 + e30aaa0

File tree

12 files changed: +282 lines added, -69 lines removed


README.md

Lines changed: 1 addition & 1 deletion
@@ -575,7 +575,7 @@ We really value our community and the contributions made by our wonderful users.
 
 To connect with us and other community members, we invite you to join our Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
 * Head to the `#torchchat-general` channel for general questions, discussion, and community support.
-* Join the `#torchchat-contribution` channel if you're interested in contributing directly to project development.
+* Join the `#torchchat-contributors` channel if you're interested in contributing directly to project development.
 
 Looking forward to discussing with you about torchchat future!

tokenizer/base.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Abstract base class for all tokenizer classes in python matching c++ interface.
"""

# Standard
from abc import ABC, abstractmethod
from typing import List


class TokenizerBase(ABC):
    __doc__ = __doc__

    @abstractmethod
    def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode the given string and optionally include bos/eos tokens"""

    @abstractmethod
    def decode(self, ids: List[int]) -> str:
        """Decode the given token ids into a string"""

    @abstractmethod
    def bos_id(self) -> int:
        """The id of the begin-of-string token"""

    @abstractmethod
    def eos_id(self) -> int:
        """The id of the end-of-string token"""

tokenizer/hf_tokenizer.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Standard
from typing import List, Optional
import json
import os

# Third Party
from tokenizers import Tokenizer

# Local
from .base import TokenizerBase


class HFTokenizer(TokenizerBase):
    """
    Wrapper around the Huggingface `tokenizers` library for API compatibility
    """

    def __init__(self, file_path: str):
        # If the path is a directory, look for "tokenizer.json" which is
        # standard for transformers checkpoints and also look for the
        # "tokenizer_config.json" file to parse eos/bos tokens
        if os.path.isdir(file_path):
            tokenizer_path = os.path.join(file_path, "tokenizer.json")
            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
        else:
            tokenizer_path = file_path
            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
        if not os.path.isfile(tokenizer_path):
            tokenizer_config_path = None

        # Load the tokenizer itself
        self._tokenizer = Tokenizer.from_file(tokenizer_path)

        # If available, parse bos/eos tokens from the tokenizer config
        self._bos_id, self._eos_id = None, None
        if tokenizer_config_path is not None:
            with open(tokenizer_config_path, "r") as handle:
                tok_config = json.load(handle)
            bos_token = tok_config.get("bos_token")
            eos_token = tok_config.get("eos_token")
            if bos_token is not None:
                self._bos_id = self._tokenizer.token_to_id(bos_token)
            if eos_token is not None:
                self._eos_id = self._tokenizer.token_to_id(eos_token)

        # If no eos/bos tokens found, go looking for them!
        if None in [self._bos_id, self._eos_id]:
            tok_content = json.loads(self._tokenizer.to_str())
            if self._bos_id is None:
                self._bos_id = self._look_for_special_token(tok_content, ["begin", "text"])
            if self._eos_id is None:
                self._eos_id = self._look_for_special_token(tok_content, ["end", "text"])

        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"

    @staticmethod
    def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]:
        candidate_toks = added_tokens
        for search_str in search_strs:
            candidate_toks = [
                tok for tok in candidate_toks
                if tok["special"] and search_str in tok["content"]
            ]
        if len(candidate_toks) == 1:
            return candidate_toks[0]["id"]

    def encode(
        self,
        s: str,
        *,
        bos: bool = False,
        eos: bool = False,
    ) -> List[int]:
        res = self._tokenizer.encode(s, add_special_tokens=bos).ids
        if eos and (not res or res[-1] != self._eos_id):
            res.append(self._eos_id)
        return res

    def decode(self, ids: List[int]) -> str:
        return self._tokenizer.decode(ids)

    def bos_id(self) -> int:
        return self._bos_id

    def eos_id(self) -> int:
        return self._eos_id
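
As a usage sketch for the new wrapper (the checkpoint path is a placeholder; any transformers-style directory containing tokenizer.json, and optionally tokenizer_config.json, should work):

from tokenizer.hf_tokenizer import HFTokenizer

# Placeholder path, not taken from the commit.
tok = HFTokenizer("checkpoints/my-model")

ids = tok.encode("Hello, torchchat!", bos=True, eos=True)
assert ids[-1] == tok.eos_id()  # encode() appends EOS only when requested
print(tok.decode(ids))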

tokenizer/tiktoken.py

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,8 @@
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 
+from .base import TokenizerBase
+
 
 logger = getLogger(__name__)
 
@@ -38,7 +40,7 @@ class Message(TypedDict):
 Dialog = Sequence[Message]
 
 
-class Tokenizer:
+class Tokenizer(TokenizerBase):
     """
     tokenizing and encoding/decoding text using the Tiktoken tokenizer.
     """

torchchat/cli/builder.py

Lines changed: 55 additions & 5 deletions
@@ -215,6 +215,7 @@ class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
+    is_hf_tokenizer: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
@@ -224,6 +225,7 @@ def __post_init__(self):
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
             self.is_tiktoken = True
             self.is_sentencepiece = False
+            self.is_hf_tokenizer = False
             return
         except:
             pass
@@ -234,12 +236,25 @@
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
             self.is_tiktoken = False
             self.is_sentencepiece = True
+            self.is_hf_tokenizer = False
+            return
+        except:
+            pass
+
+        try:
+            from tokenizer.hf_tokenizer import HFTokenizer
+
+            self.t = HFTokenizer(str(self.tokenizer_path))
+            self.is_tiktoken = False
+            self.is_sentencepiece = False
+            self.is_hf_tokenizer = True
             return
         except:
             pass
 
         self.is_tiktoken = False
         self.is_sentencepiece = False
+        self.is_hf_tokenizer = False
         self.t = None
         return
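
The cascade above probes candidate tokenizers in a fixed order (tiktoken, then SentencePiece, then the Huggingface `tokenizers` wrapper) and keeps the first one that loads; a loader that raises simply falls through to the next. The same pattern in isolation, as a sketch (the helper and import paths are assumptions for illustration):

from typing import Any, Optional, Tuple


def detect_tokenizer(path: str) -> Tuple[Optional[str], Optional[Any]]:
    """Sketch: return (kind, tokenizer) for the first loader that succeeds."""

    def load_tiktoken():
        from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
        return TiktokenTokenizer(model_path=path)

    def load_sentencepiece():
        from sentencepiece import SentencePieceProcessor
        return SentencePieceProcessor(model_file=path)

    def load_hf():
        from tokenizer.hf_tokenizer import HFTokenizer
        return HFTokenizer(path)

    for kind, load in [
        ("tiktoken", load_tiktoken),
        ("sentencepiece", load_sentencepiece),
        ("hf_tokenizer", load_hf),
    ]:
        try:
            return kind, load()
        except Exception:
            continue
    return None, None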

@@ -251,16 +266,27 @@ def validate_model(
         if model is None:
             return
 
-        if self.is_tiktoken == self.is_sentencepiece:
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
+        is_hf_tokenizer = self.is_hf_tokenizer
         use_tiktoken = model.config.use_tiktoken
+        use_hf_tokenizer = model.config.use_hf_tokenizer
+        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
-        if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
+        if (
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
+        ):
             raise RuntimeError(
-                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
+                "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
+                    tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
+                    tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
+                    model_description,
+                )
             )
 
         return
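
To make the new validation concrete, a worked example with hypothetical flag values, where the model config expects tiktoken but a SentencePiece tokenizer was detected:

# Hypothetical values, for illustration only.
is_tiktoken, is_hf_tokenizer, is_sentencepiece = False, False, True
use_tiktoken, use_hf_tokenizer = True, False
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)  # False

mismatch = (
    (is_tiktoken and not use_tiktoken)
    or (is_hf_tokenizer and not use_hf_tokenizer)
    or (is_sentencepiece and not use_sentencepiece)
)
print(mismatch)  # True: validate_model would raise the RuntimeError above
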
@@ -510,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         model = _load_model_default(builder_args)
     # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
+    if builder_args.dso_path or builder_args.aoti_package_path:
+        # AOTI-compiled model will load its own weights.
+        # Release weights here to avoid OOM
+        import gc
+        if hasattr(model, "model"):
+            model.model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()

@@ -558,6 +593,12 @@ def _initialize_model(
         # attributes will NOT be seen by the AOTI-compiled forward
         # function, e.g. calling model.setup_cache will NOT touch
         # AOTI compiled and maintained model buffers such as kv_cache.
+        # Using the cpp runner to run the AOTI-compiled model is recommended.
+
+        def do_nothing(max_batch_size, max_seq_length):
+            pass
+        model.setup_caches = do_nothing
+
         model.forward = torch._export.aot_load(
             str(builder_args.dso_path.absolute()), builder_args.device
         )
@@ -591,6 +632,11 @@
         aoti_compiled_model = load_package(
             str(builder_args.aoti_package_path.absolute())
         )
+
+        def do_nothing(max_batch_size, max_seq_length):
+            pass
+        model.setup_caches = do_nothing
+
         model.forward = aoti_compiled_model
         metadata = aoti_compiled_model.get_metadata()
         builder_args.device = metadata["AOTI_DEVICE_KEY"]
@@ -655,5 +701,9 @@
     return model
 
 
-def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
-    return "TikToken" if tiktoken else "SentencePiece"
+def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
+    if tiktoken:
+        return "TikToken"
+    if tokenizers:
+        return "Tokenizers"
+    return "SentencePiece"

torchchat/cli/cli.py

Lines changed: 13 additions & 7 deletions
@@ -5,20 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import importlib.metadata
 import json
 import logging
 import os
 import sys
 from pathlib import Path
 
-import torch
-
-from torchchat.cli.download import download_and_convert, is_model_downloaded
-
 from torchchat.utils.build_utils import (
     allowable_dtype_names,
     allowable_params_table,
-    get_device_str,
 )
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -42,6 +38,9 @@
 
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
+    # Local import to avoid unnecessary expensive imports
+    from torchchat.cli.download import download_and_convert, is_model_downloaded
+
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
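
The moves above defer `torch` and the download helpers until a subcommand actually needs them, so CLI startup no longer pays their import cost. A rough way to measure what such an import costs (a sketch, assuming torch is installed):

import importlib
import time

start = time.perf_counter()
importlib.import_module("torch")
print(f"import torch: {time.perf_counter() - start:.2f}s")
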
@@ -498,9 +497,10 @@ def _add_speculative_execution_args(parser) -> None:
 
 
 def arg_init(args):
-    if not (torch.__version__ > "2.3"):
+    torch_version = importlib.metadata.version("torch")
+    if not torch_version or (torch_version <= "2.3"):
         raise RuntimeError(
-            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )
 
     if sys.version_info.major != 3 or sys.version_info.minor < 10:
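
One caveat with the new check: comparing version strings lexicographically misorders multi-digit components (for example, "2.10" <= "2.3" is True as strings). A more robust variant could parse the versions first, e.g. with the third-party packaging library (an assumption; this diff does not use it):

from importlib.metadata import version

from packaging.version import Version  # assumption: packaging is installed

if Version(version("torch")) <= Version("2.3"):
    raise RuntimeError("torchchat requires PyTorch newer than 2.3")
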
@@ -521,6 +521,9 @@ def arg_init(args):
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
+        # Localized import to minimize expensive imports
+        from torchchat.utils.build_utils import get_device_str
+
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
         )
@@ -534,5 +537,8 @@
         vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
+        # Localized import to minimize expensive imports
+        import torch
+
         torch.manual_seed(args.seed)
     return args
