@@ -190,6 +190,7 @@ class TokenizerArgs:
190
190
tokenizer_path : Optional [Union [Path , str ]] = None
191
191
is_sentencepiece : bool = False
192
192
is_tiktoken : bool = False
193
+ is_tokenizers : bool = False
193
194
t : Optional [Any ] = None
194
195
195
196
def __post_init__ (self ):
@@ -199,6 +200,7 @@ def __post_init__(self):
199
200
self .t = TiktokenTokenizer (model_path = str (self .tokenizer_path ))
200
201
self .is_tiktoken = True
201
202
self .is_sentencepiece = False
203
+ self .is_tokenizers = False
202
204
return
203
205
except :
204
206
pass
@@ -209,12 +211,25 @@ def __post_init__(self):
209
211
self .t = SentencePieceProcessor (model_file = str (self .tokenizer_path ))
210
212
self .is_tiktoken = False
211
213
self .is_sentencepiece = True
214
+ self .is_tokenizers = False
215
+ return
216
+ except :
217
+ pass
218
+
219
+ try :
220
+ from tokenizer .tokenizers import TokenizersTokenizer
221
+
222
+ self .t = TokenizersTokenizer (str (self .tokenizer_path ))
223
+ self .is_tiktoken = False
224
+ self .is_sentencepiece = False
225
+ self .is_tokenizers = True
212
226
return
213
227
except :
214
228
pass
215
229
216
230
self .is_tiktoken = False
217
231
self .is_sentencepiece = False
232
+ self .is_tokenizers = False
218
233
self .t = None
219
234
return
220
235
@@ -226,16 +241,27 @@ def validate_model(
226
241
if model is None :
227
242
return
228
243
229
- if self .is_tiktoken == self .is_sentencepiece :
244
+ if len ( list ( filter ( lambda x : x , [ self .is_tiktoken , self . is_tokenizers , self .is_sentencepiece ]))) != 1 :
230
245
raise RuntimeError (f"no tokenizer was found at { self .tokenizer_path } " )
231
246
232
247
is_tiktoken = self .is_tiktoken
233
248
is_sentencepiece = self .is_sentencepiece
249
+ is_tokenizers = self .is_tokenizers
234
250
use_tiktoken = model .config .use_tiktoken
251
+ use_tokenizers = model .config .use_tokenizers
252
+ use_sentencepiece = not (use_tiktoken or use_tokenizers )
235
253
236
- if not (is_tiktoken == use_tiktoken ) or not (is_sentencepiece != use_tiktoken ):
254
+ if (
255
+ (is_tiktoken and not use_tiktoken ) or
256
+ (is_tokenizers and not use_tokenizers ) or
257
+ (is_sentencepiece and not use_sentencepiece )
258
+ ):
237
259
raise RuntimeError (
238
- f"model-specified tokenizer ({ tokenizer_setting_to_name (use_tiktoken )} ) does not match provided tokenizer ({ tokenizer_setting_to_name (is_tiktoken )} ) for { model_description } "
260
+ "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}" .format (
261
+ tokenizer_setting_to_name (use_tiktoken , use_tokenizers ),
262
+ tokenizer_setting_to_name (is_tiktoken , is_tokenizers ),
263
+ model_description ,
264
+ )
239
265
)
240
266
241
267
return
@@ -591,5 +617,9 @@ def _initialize_model(
591
617
return model
592
618
593
619
594
- def tokenizer_setting_to_name (tiktoken : bool = False ) -> str :
595
- return "TikToken" if tiktoken else "SentencePiece"
620
def tokenizer_setting_to_name(tiktoken: bool = False, tokenizers: bool = False) -> str:
    """Map tokenizer-selection flags to a human-readable tokenizer name.

    Defaults restored from the pre-diff signature so existing callers that
    pass zero or one positional argument keep working.

    Args:
        tiktoken: True when the TikToken tokenizer is selected.
        tokenizers: True when the "tokenizers"-backed tokenizer is selected.

    Returns:
        "TikToken", "Tokenizers", or "SentencePiece". If both flags are set,
        `tiktoken` takes precedence; if neither is set, SentencePiece is the
        fallback, matching the original two-way behavior.
    """
    if tiktoken:
        return "TikToken"
    if tokenizers:
        return "Tokenizers"
    return "SentencePiece"
0 commit comments