@@ -193,6 +193,7 @@ class TokenizerArgs:
193
193
tokenizer_path : Optional [Union [Path , str ]] = None
194
194
is_sentencepiece : bool = False
195
195
is_tiktoken : bool = False
196
+ is_tokenizers : bool = False
196
197
t : Optional [Any ] = None
197
198
198
199
def __post_init__ (self ):
@@ -202,6 +203,7 @@ def __post_init__(self):
202
203
self .t = TiktokenTokenizer (model_path = str (self .tokenizer_path ))
203
204
self .is_tiktoken = True
204
205
self .is_sentencepiece = False
206
+ self .is_tokenizers = False
205
207
return
206
208
except :
207
209
pass
@@ -212,12 +214,25 @@ def __post_init__(self):
212
214
self .t = SentencePieceProcessor (model_file = str (self .tokenizer_path ))
213
215
self .is_tiktoken = False
214
216
self .is_sentencepiece = True
217
+ self .is_tokenizers = False
218
+ return
219
+ except :
220
+ pass
221
+
222
+ try :
223
+ from tokenizer .tokenizers import TokenizersTokenizer
224
+
225
+ self .t = TokenizersTokenizer (str (self .tokenizer_path ))
226
+ self .is_tiktoken = False
227
+ self .is_sentencepiece = False
228
+ self .is_tokenizers = True
215
229
return
216
230
except :
217
231
pass
218
232
219
233
self .is_tiktoken = False
220
234
self .is_sentencepiece = False
235
+ self .is_tokenizers = False
221
236
self .t = None
222
237
return
223
238
@@ -229,16 +244,27 @@ def validate_model(
229
244
if model is None :
230
245
return
231
246
232
- if self .is_tiktoken == self .is_sentencepiece :
247
+ if len ( list ( filter ( lambda x : x , [ self .is_tiktoken , self . is_tokenizers , self .is_sentencepiece ]))) != 1 :
233
248
raise RuntimeError (f"no tokenizer was found at { self .tokenizer_path } " )
234
249
235
250
is_tiktoken = self .is_tiktoken
236
251
is_sentencepiece = self .is_sentencepiece
252
+ is_tokenizers = self .is_tokenizers
237
253
use_tiktoken = model .config .use_tiktoken
254
+ use_tokenizers = model .config .use_tokenizers
255
+ use_sentencepiece = not (use_tiktoken or use_tokenizers )
238
256
239
- if not (is_tiktoken == use_tiktoken ) or not (is_sentencepiece != use_tiktoken ):
257
+ if (
258
+ (is_tiktoken and not use_tiktoken ) or
259
+ (is_tokenizers and not use_tokenizers ) or
260
+ (is_sentencepiece and not use_sentencepiece )
261
+ ):
240
262
raise RuntimeError (
241
- f"model-specified tokenizer ({ tokenizer_setting_to_name (use_tiktoken )} ) does not match provided tokenizer ({ tokenizer_setting_to_name (is_tiktoken )} ) for { model_description } "
263
+ "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}" .format (
264
+ tokenizer_setting_to_name (use_tiktoken , use_tokenizers ),
265
+ tokenizer_setting_to_name (is_tiktoken , is_tokenizers ),
266
+ model_description ,
267
+ )
242
268
)
243
269
244
270
return
@@ -594,5 +620,9 @@ def _initialize_model(
594
620
return model
595
621
596
622
597
def tokenizer_setting_to_name(tiktoken: bool = False, tokenizers: bool = False) -> str:
    """Map tokenizer-selection flags to a human-readable tokenizer name.

    Defaults restored for backward compatibility: the previous version of
    this helper accepted ``tiktoken`` with a default of ``False``, so
    existing call sites may pass zero or one positional argument.

    Args:
        tiktoken: True when the TikToken tokenizer is selected.
        tokenizers: True when the ``tokenizers``-based tokenizer is selected.

    Returns:
        "TikToken", "Tokenizers", or "SentencePiece" (the fallback when
        neither flag is set). ``tiktoken`` wins if both flags are set,
        matching the original check order.
    """
    if tiktoken:
        return "TikToken"
    if tokenizers:
        return "Tokenizers"
    return "SentencePiece"
0 commit comments