Skip to content

Commit 9480258

Browse files
authored
Add HuggingFace Tokenizer support for Granite Code (#1261)
* feat(tokenizer): Add an abstract base class for additional tokenizer support Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]> * feat(tokenizers): Add a python impl of the Tokenizer interface using tokenizers This allows for all HF tokenizers to be supported in the python layer. It will need significant work to offer similar compatibility at the c++ layer. Signed-off-by: Gabe Goodhart <[email protected]> * feat(builder): Add support for using the TokenizersTokenizer in builder Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]> * feat(tokenizers): Add and plumb the option to use the "tokenizers" tokenizer Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]> * fix(tokenizers): Fix how bos/eos tokens are parsed from tokenizers (lib) Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]> * fix(hf_tokenizer): Rename to HFTokenizer and corresponding flags #1251 Branch: TokenizersTokenizer-1251 Co-Authored-By: [email protected] Signed-off-by: Gabe Goodhart <[email protected]> --------- Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 4510ba0 commit 9480258

File tree

5 files changed

+169
-7
lines changed

5 files changed

+169
-7
lines changed

tokenizer/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
"""
7+
Abstract base class for all tokenizer classes in python matching c++ interface.
8+
"""
9+
10+
# Standard
11+
from abc import ABC, abstractmethod
12+
from typing import List
13+
14+
15+
class TokenizerBase(ABC):
    """Common interface for all python tokenizer implementations, mirroring
    the c++ tokenizer interface used elsewhere in the project."""

    @abstractmethod
    def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]:
        """Convert a string into token ids, optionally adding bos/eos tokens."""
        ...

    @abstractmethod
    def decode(self, ids: List[int]) -> str:
        """Convert a sequence of token ids back into a string."""
        ...

    @abstractmethod
    def bos_id(self) -> int:
        """Return the id of the begin-of-string token."""
        ...

    @abstractmethod
    def eos_id(self) -> int:
        """Return the id of the end-of-string token."""
        ...

tokenizer/hf_tokenizer.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Standard
8+
from typing import List, Optional
9+
import json
10+
import os
11+
12+
# Third Party
13+
from tokenizers import Tokenizer
14+
15+
# Local
16+
from .base import TokenizerBase
17+
18+
19+
class HFTokenizer(TokenizerBase):
    """
    Wrapper around the Huggingface `tokenizers` library for API compatibility
    with the project's other tokenizer implementations.
    """

    def __init__(self, file_path: str):
        """Load a tokenizer.json (or a directory containing one) and resolve
        the bos/eos token ids.

        Args:
            file_path: Path to a tokenizer.json file, or to a directory that
                contains tokenizer.json and (optionally) tokenizer_config.json.

        Raises:
            ValueError: If no bos/eos token ids can be determined.
        """
        # If the path is a directory, look for "tokenizer.json" which is
        # standard for transformers checkpoints and also look for the
        # "tokenizer_config.json" file to parse eos/bos tokens
        if os.path.isdir(file_path):
            tokenizer_path = os.path.join(file_path, "tokenizer.json")
            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
        else:
            tokenizer_path = file_path
            tokenizer_config_path = os.path.join(
                os.path.dirname(file_path), "tokenizer_config.json"
            )
        # BUGFIX: the existence check must guard the *config* file we are about
        # to open, not the tokenizer file (which Tokenizer.from_file validates
        # itself). Previously a missing tokenizer_config.json crashed open().
        if not os.path.isfile(tokenizer_config_path):
            tokenizer_config_path = None

        # Load the tokenizer itself
        self._tokenizer = Tokenizer.from_file(tokenizer_path)

        # If available, parse bos/eos tokens from the tokenizer config
        self._bos_id, self._eos_id = None, None
        if tokenizer_config_path is not None:
            with open(tokenizer_config_path, "r") as handle:
                tok_config = json.load(handle)
            bos_token = tok_config.get("bos_token")
            eos_token = tok_config.get("eos_token")
            # Depending on the transformers version that wrote the config,
            # bos/eos entries may be plain strings or AddedToken-style dicts
            # with a "content" field — accept both.
            if isinstance(bos_token, dict):
                bos_token = bos_token.get("content")
            if isinstance(eos_token, dict):
                eos_token = eos_token.get("content")
            if bos_token is not None:
                self._bos_id = self._tokenizer.token_to_id(bos_token)
            if eos_token is not None:
                self._eos_id = self._tokenizer.token_to_id(eos_token)

        # If no eos/bos tokens found, go looking for them in the tokenizer's
        # own list of added special tokens.
        if None in [self._bos_id, self._eos_id]:
            tok_content = json.loads(self._tokenizer.to_str())
            # BUGFIX: _look_for_special_token expects the list of added-token
            # dicts, not the whole serialized tokenizer mapping (iterating the
            # mapping yields string keys and `tok["special"]` would raise).
            added_tokens = tok_content.get("added_tokens", [])
            if self._bos_id is None:
                self._bos_id = self._look_for_special_token(added_tokens, ["begin", "text"])
            if self._eos_id is None:
                self._eos_id = self._look_for_special_token(added_tokens, ["end", "text"])

        # Raise instead of assert: asserts are stripped under `python -O` and
        # this is real input validation, not an internal invariant.
        if None in [self._bos_id, self._eos_id]:
            raise ValueError("Unable to find BOS/EOS tokens")

    @staticmethod
    def _look_for_special_token(
        added_tokens: List[dict], search_strs: List[str]
    ) -> Optional[int]:
        """Return the id of the unique special token whose content contains
        every search string, or None when there is no single match.

        Args:
            added_tokens: List of added-token dicts from the serialized
                tokenizer (each with "id", "content", and "special" keys).
            search_strs: Substrings that must all appear in the token content.
        """
        candidate_toks = added_tokens
        for search_str in search_strs:
            candidate_toks = [
                tok
                for tok in candidate_toks
                if tok["special"] and search_str in tok["content"]
            ]
        if len(candidate_toks) == 1:
            return candidate_toks[0]["id"]
        return None

    def encode(
        self,
        s: str,
        *,
        bos: bool = False,
        eos: bool = False,
    ) -> List[int]:
        """Encode the given string, optionally adding bos (via the library's
        special-token handling) and appending eos if not already present."""
        res = self._tokenizer.encode(s, add_special_tokens=bos).ids
        # BUGFIX: was `self._eos_token`, an attribute that is never assigned,
        # so any call with eos=True raised AttributeError. The id computed in
        # __init__ is `self._eos_id`.
        if eos and (not res or res[-1] != self._eos_id):
            res.append(self._eos_id)
        return res

    def decode(self, ids: List[int]) -> str:
        """Decode the given token ids back into a string."""
        return self._tokenizer.decode(ids)

    def bos_id(self) -> int:
        """The id of the begin-of-string token."""
        return self._bos_id

    def eos_id(self) -> int:
        """The id of the end-of-string token."""
        return self._eos_id

tokenizer/tiktoken.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import tiktoken
2424
from tiktoken.load import load_tiktoken_bpe
2525

26+
from .base import TokenizerBase
27+
2628

2729
logger = getLogger(__name__)
2830

@@ -38,7 +40,7 @@ class Message(TypedDict):
3840
Dialog = Sequence[Message]
3941

4042

41-
class Tokenizer:
43+
class Tokenizer(TokenizerBase):
4244
"""
4345
tokenizing and encoding/decoding text using the Tiktoken tokenizer.
4446
"""

torchchat/cli/builder.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ class TokenizerArgs:
215215
tokenizer_path: Optional[Union[Path, str]] = None
216216
is_sentencepiece: bool = False
217217
is_tiktoken: bool = False
218+
is_hf_tokenizer: bool = False
218219
t: Optional[Any] = None
219220

220221
def __post_init__(self):
@@ -224,6 +225,7 @@ def __post_init__(self):
224225
self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
225226
self.is_tiktoken = True
226227
self.is_sentencepiece = False
228+
self.is_hf_tokenizer = False
227229
return
228230
except:
229231
pass
@@ -234,12 +236,25 @@ def __post_init__(self):
234236
self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
235237
self.is_tiktoken = False
236238
self.is_sentencepiece = True
239+
self.is_hf_tokenizer = False
240+
return
241+
except:
242+
pass
243+
244+
try:
245+
from tokenizer.hf_tokenizer import HFTokenizer
246+
247+
self.t = HFTokenizer(str(self.tokenizer_path))
248+
self.is_tiktoken = False
249+
self.is_sentencepiece = False
250+
self.is_hf_tokenizer = True
237251
return
238252
except:
239253
pass
240254

241255
self.is_tiktoken = False
242256
self.is_sentencepiece = False
257+
self.is_hf_tokenizer = False
243258
self.t = None
244259
return
245260

@@ -251,16 +266,27 @@ def validate_model(
251266
if model is None:
252267
return
253268

254-
if self.is_tiktoken == self.is_sentencepiece:
269+
if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
255270
raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
256271

257272
is_tiktoken = self.is_tiktoken
258273
is_sentencepiece = self.is_sentencepiece
274+
is_hf_tokenizer = self.is_hf_tokenizer
259275
use_tiktoken = model.config.use_tiktoken
276+
use_hf_tokenizer = model.config.use_hf_tokenizer
277+
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
260278

261-
if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
279+
if (
280+
(is_tiktoken and not use_tiktoken) or
281+
(is_hf_tokenizer and not use_hf_tokenizer) or
282+
(is_sentencepiece and not use_sentencepiece)
283+
):
262284
raise RuntimeError(
263-
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
285+
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
286+
tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
287+
tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
288+
model_description,
289+
)
264290
)
265291

266292
return
@@ -655,5 +681,9 @@ def _initialize_model(
655681
return model
656682

657683

658-
def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
659-
return "TikToken" if tiktoken else "SentencePiece"
684+
def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
    """Translate tokenizer selection flags into a human-readable name.

    Precedence follows the flag order: tiktoken wins over tokenizers, and
    SentencePiece is the fallback when neither flag is set.
    """
    return "TikToken" if tiktoken else ("Tokenizers" if tokenizers else "SentencePiece")

torchchat/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ class TransformerArgs:
270270
norm_eps: float = 1e-5
271271
multiple_of: int = 256
272272
ffn_dim_multiplier: Optional[int] = None
273+
# Select the desired tokenizer. Defaults to sentencepiece
273274
use_tiktoken: bool = False
275+
use_hf_tokenizer: bool = False
274276
max_seq_length: int = 8192
275277
rope_scaling: Optional[Dict[str, Any]] = None
276278
# For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
327329
model_type: ModelType
328330
transformer_args: Dict[str, Dict[str, Any]]
329331
use_tiktoken: bool
332+
use_hf_tokenizer: bool
330333

331334
def __init__(
332335
self,
333336
transformer_args: Dict[str, Dict[str, Any]],
334337
model_type: ModelType = ModelType.TextOnly,
335338
use_tiktoken: bool = False,
339+
use_hf_tokenizer: bool = False,
336340
) -> None:
337341
self._sanity_check(transformer_args, model_type)
338342

@@ -341,6 +345,7 @@ def __init__(
341345

342346
# Model-level attributes
343347
self.use_tiktoken = use_tiktoken
348+
self.use_hf_tokenizer = use_hf_tokenizer
344349

345350
def _sanity_check(
346351
self,
@@ -367,7 +372,8 @@ def from_params(cls, params_path):
367372
}
368373

369374
use_tiktoken = loaded_params.get("use_tiktoken", False)
370-
return cls(transformer_args, model_type, use_tiktoken)
375+
use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
376+
return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)
371377

372378
@classmethod
373379
def from_table(cls, name: str):

0 commit comments

Comments
 (0)