
Commit f2cba4c

fix(tokenizers): Fix how bos/eos tokens are parsed from tokenizers (lib)
Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: 79e4ccb


tokenizer/tokenizers.py

Lines changed: 50 additions & 22 deletions
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.

 # Standard
-from typing import List
+from typing import List, Optional
 import json
+import os

 # Third Party
 from tokenizers import Tokenizer
@@ -21,26 +22,53 @@ class TokenizersTokenizer(TokenizerBase):
     """

     def __init__(self, file_path: str):
-        self._tokenizer = Tokenizer.from_file(file_path)
-        # The BOS and EOS tokens are not easily visible from the tokenizer
-        # object itself, so we extract them at construction with a sample call
-        self._bos_token = self._tokenizer.encode("Test", add_special_tokens=True).ids[0]
-        # There is no explicit BOS token in many tokenizers, so we look for a
-        # single special token that most resembles the BOS token.
-        self._eos_token = None
-        tok_content = json.loads(self._tokenizer.to_str())
-        end_toks = [
-            tok for tok in tok_content['added_tokens']
-            if tok["special"] and "end" in tok["content"]
-        ]
-        assert end_toks, "Unable to find an EOS token in the added tokens"
-        if len(end_toks) > 1:
-            end_text_toks = [
-                tok for tok in end_toks if "text" in tok["content"]
+        # If the path is a directory, look for "tokenizer.json" which is
+        # standard for transformers checkpoints and also look for the
+        # "tokenizer_config.json" file to parse eos/bos tokens
+        if os.path.isdir(file_path):
+            tokenizer_path = os.path.join(file_path, "tokenizer.json")
+            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
+        else:
+            tokenizer_path = file_path
+            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
+        if not os.path.isfile(tokenizer_path):
+            tokenizer_config_path = None
+
+        # Load the tokenizer itself
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+
+        # If available, parse bos/eos tokens from the tokenizer config
+        self._bos_id, self._eos_id = None, None
+        if tokenizer_config_path is not None:
+            with open(tokenizer_config_path, "r") as handle:
+                tok_config = json.load(handle)
+            bos_token = tok_config.get("bos_token")
+            eos_token = tok_config.get("eos_token")
+            if bos_token is not None:
+                self._bos_id = self._tokenizer.token_to_id(bos_token)
+            if eos_token is not None:
+                self._eos_id = self._tokenizer.token_to_id(eos_token)
+
+        # If no eos/bos tokens found, go looking for them!
+        if None in [self._bos_id, self._eos_id]:
+            tok_content = json.loads(self._tokenizer.to_str())
+            if self._bos_id is None:
+                self._bos_id = self._look_for_special_token(tok_content, ["begin", "text"])
+            if self._eos_id is None:
+                self._eos_id = self._look_for_special_token(tok_content, ["end", "text"])
+
+        assert None not in [self._bos_id, self._eos_id], "Unable to find an BOS/EOS tokens"
+
+    @staticmethod
+    def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]:
+        candidate_toks = added_tokens
+        for search_str in search_strs:
+            candidate_toks = [
+                tok for tok in candidate_toks
+                if tok["special"] and search_str in tok["content"]
             ]
-            if len(end_text_toks) == 1:
-                self._eos_token = end_text_toks[0]["id"]
-        assert self._eos_token is not None, "Unable to find an EOS token in the added tokens"
+        if len(candidate_toks) == 1:
+            return candidate_toks[0]["id"]

     def encode(
         self,
@@ -58,7 +86,7 @@ def decode(self, ids: List[int]) -> str:
         return self._tokenizer.decode(ids)

     def bos_id(self) -> int:
-        return self._bos_token
+        return self._bos_id

     def eos_id(self) -> int:
-        return self._eos_token
+        return self._eos_id
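For reference, a minimal standalone sketch of the config-first lookup above, using public tokenizers APIs (Tokenizer.from_file, token_to_id, id_to_token). The checkpoint directory name is hypothetical, and the token_content helper that tolerates dict-valued bos_token/eos_token entries is an assumption about older transformers-style configs, not something this commit does:

import json
import os

from tokenizers import Tokenizer

# Hypothetical checkpoint directory laid out like a transformers
# checkpoint: tokenizer.json plus (optionally) tokenizer_config.json.
ckpt_dir = "checkpoints/granite-code"

tokenizer = Tokenizer.from_file(os.path.join(ckpt_dir, "tokenizer.json"))

def token_content(entry):
    # Assumption: some older configs store bos_token/eos_token as an
    # AddedToken dict ({"content": "...", ...}) rather than a plain string.
    return entry["content"] if isinstance(entry, dict) else entry

bos_id, eos_id = None, None
config_path = os.path.join(ckpt_dir, "tokenizer_config.json")
if os.path.isfile(config_path):
    with open(config_path, "r") as handle:
        tok_config = json.load(handle)
    bos_token = tok_config.get("bos_token")
    eos_token = tok_config.get("eos_token")
    if bos_token is not None:
        # token_to_id returns None when the token string is unknown
        bos_id = tokenizer.token_to_id(token_content(bos_token))
    if eos_token is not None:
        eos_id = tokenizer.token_to_id(token_content(eos_token))

# Sanity check: map the resolved ids back to their string forms
print(bos_id, eos_id)
if bos_id is not None:
    print(tokenizer.id_to_token(bos_id))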

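The fallback path narrows the "added_tokens" list from the serialized tokenizer JSON by applying one substring filter per search string, and accepts an id only when exactly one candidate survives. A self-contained sketch of that narrowing, with a hypothetical added_tokens list (Llama-3-style token names, chosen purely for illustration):

from typing import List, Optional

def look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
    # Keep only special tokens whose content contains the search string,
    # one filter pass per search string.
    candidate_toks = added_tokens
    for search_str in search_strs:
        candidate_toks = [
            tok for tok in candidate_toks
            if tok["special"] and search_str in tok["content"]
        ]
    # Accept only an unambiguous match; otherwise return None so the
    # caller can fail loudly (as the assert in __init__ does).
    if len(candidate_toks) == 1:
        return candidate_toks[0]["id"]
    return None

# Hypothetical entries shaped like the "added_tokens" list in tokenizer.json
added_tokens = [
    {"id": 128000, "content": "<|begin_of_text|>", "special": True},
    {"id": 128001, "content": "<|end_of_text|>", "special": True},
    {"id": 128007, "content": "<|end_header_id|>", "special": True},
]

print(look_for_special_token(added_tokens, ["begin", "text"]))  # 128000
print(look_for_special_token(added_tokens, ["end", "text"]))    # 128001: "end" keeps two, "text" keeps one

Filtering by successive substrings lets the same helper resolve both ids without hard-coding any one model family's token names.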