 # LICENSE file in the root directory of this source tree.
 
 # Standard
-from typing import List
+from typing import List, Optional
 import json
+import os
 
 # Third Party
 from tokenizers import Tokenizer
@@ -21,26 +22,53 @@ class TokenizersTokenizer(TokenizerBase):
     """
 
     def __init__(self, file_path: str):
-        self._tokenizer = Tokenizer.from_file(file_path)
-        # The BOS and EOS tokens are not easily visible from the tokenizer
-        # object itself, so we extract them at construction with a sample call
-        self._bos_token = self._tokenizer.encode("Test", add_special_tokens=True).ids[0]
-        # There is no explicit EOS token in many tokenizers, so we look for a
-        # single special token that most resembles the EOS token.
-        self._eos_token = None
-        tok_content = json.loads(self._tokenizer.to_str())
-        end_toks = [
-            tok for tok in tok_content['added_tokens']
-            if tok["special"] and "end" in tok["content"]
-        ]
-        assert end_toks, "Unable to find an EOS token in the added tokens"
-        if len(end_toks) > 1:
-            end_text_toks = [
-                tok for tok in end_toks if "text" in tok["content"]
+        # If the path is a directory, look for "tokenizer.json", which is
+        # standard for transformers checkpoints, and also look for the
+        # "tokenizer_config.json" file to parse the eos/bos tokens
+        if os.path.isdir(file_path):
+            tokenizer_path = os.path.join(file_path, "tokenizer.json")
+            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
+        else:
+            tokenizer_path = file_path
+            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
+        if not os.path.isfile(tokenizer_config_path):
+            tokenizer_config_path = None
+
+        # Load the tokenizer itself
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+
+        # If available, parse the bos/eos tokens from the tokenizer config
+        self._bos_id, self._eos_id = None, None
+        if tokenizer_config_path is not None:
+            with open(tokenizer_config_path, "r") as handle:
+                tok_config = json.load(handle)
+            bos_token = tok_config.get("bos_token")
+            eos_token = tok_config.get("eos_token")
+            if bos_token is not None:
+                self._bos_id = self._tokenizer.token_to_id(bos_token)
+            if eos_token is not None:
+                self._eos_id = self._tokenizer.token_to_id(eos_token)
+
+        # If no eos/bos tokens were found, go looking for them!
+        if None in [self._bos_id, self._eos_id]:
+            tok_content = json.loads(self._tokenizer.to_str())
+            if self._bos_id is None:
+                self._bos_id = self._look_for_special_token(tok_content["added_tokens"], ["begin", "text"])
+            if self._eos_id is None:
+                self._eos_id = self._look_for_special_token(tok_content["added_tokens"], ["end", "text"])
+
+        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"
+
+    @staticmethod
+    def _look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
+        candidate_toks = added_tokens
+        for search_str in search_strs:
+            candidate_toks = [
+                tok for tok in candidate_toks
+                if tok["special"] and search_str in tok["content"]
             ]
-            if len(end_text_toks) == 1:
-                self._eos_token = end_text_toks[0]["id"]
-        assert self._eos_token is not None, "Unable to find an EOS token in the added tokens"
+        if len(candidate_toks) == 1:
+            return candidate_toks[0]["id"]
 
     def encode(
         self,
@@ -58,7 +86,7 @@ def decode(self, ids: List[int]) -> str:
         return self._tokenizer.decode(ids)
 
     def bos_id(self) -> int:
-        return self._bos_token
+        return self._bos_id
 
     def eos_id(self) -> int:
-        return self._eos_token
+        return self._eos_id
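
For context on the fallback heuristic: `_look_for_special_token` repeatedly narrows the `added_tokens` entries from `tokenizer.json`, keeping only special tokens whose text contains each search string in turn, and returns an id only when exactly one candidate remains. A minimal sketch of that narrowing, using hypothetical token entries (the names below are illustrative, not taken from the PR):

```python
# Minimal sketch of the fallback search. The added_tokens entries here are
# hypothetical examples of the shape found in tokenizer.json.
from typing import List, Optional

def look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
    candidate_toks = added_tokens
    for search_str in search_strs:
        # Keep only special tokens whose content contains the search string
        candidate_toks = [
            tok for tok in candidate_toks
            if tok["special"] and search_str in tok["content"]
        ]
    # Only answer when the match is unambiguous
    if len(candidate_toks) == 1:
        return candidate_toks[0]["id"]
    return None

added_tokens = [
    {"id": 1, "special": True, "content": "<|begin_of_text|>"},
    {"id": 2, "special": True, "content": "<|end_of_text|>"},
    {"id": 3, "special": True, "content": "<|end_header_id|>"},
]
# "end" alone matches ids 2 and 3; adding "text" disambiguates to id 2.
assert look_for_special_token(added_tokens, ["end", "text"]) == 2
assert look_for_special_token(added_tokens, ["begin", "text"]) == 1
```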
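And a hedged usage sketch, assuming a transformers-style checkpoint layout (the paths are illustrative, not part of the PR): pointing the constructor at a directory picks up `tokenizer.json` and `tokenizer_config.json`; pointing it at a `tokenizer.json` file looks for the config next to it and falls back to the heuristic search when no config is found.

```python
# Hypothetical checkpoint paths; assumes TokenizersTokenizer has been
# imported from this repo's tokenizer module.
tokenizer = TokenizersTokenizer("checkpoints/my-model")
print(tokenizer.bos_id(), tokenizer.eos_id())

# A direct file path also works; tokenizer_config.json is looked for in the
# same directory and ignored if it does not exist.
tokenizer = TokenizersTokenizer("checkpoints/my-model/tokenizer.json")
```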