@@ -270,7 +270,9 @@ class TransformerArgs:
270
270
norm_eps : float = 1e-5
271
271
multiple_of : int = 256
272
272
ffn_dim_multiplier : Optional [int ] = None
273
+ # Select the desired tokenizer. Defaults to SentencePiece.
273
274
use_tiktoken : bool = False
275
+ use_tokenizers : bool = False
274
276
max_seq_length : int = 8192
275
277
rope_scaling : Optional [Dict [str , Any ]] = None
276
278
# For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
327
329
model_type : ModelType
328
330
transformer_args : Dict [str , Dict [str , Any ]]
329
331
use_tiktoken : bool
332
+ use_tokenizers : bool
330
333
331
334
def __init__ (
332
335
self ,
333
336
transformer_args : Dict [str , Dict [str , Any ]],
334
337
model_type : ModelType = ModelType .TextOnly ,
335
338
use_tiktoken : bool = False ,
339
+ use_tokenizers : bool = False ,
336
340
) -> None :
337
341
self ._sanity_check (transformer_args , model_type )
338
342
@@ -341,6 +345,7 @@ def __init__(
341
345
342
346
# Model-level attributes
343
347
self .use_tiktoken = use_tiktoken
348
+ self .use_tokenizers = use_tokenizers
344
349
345
350
def _sanity_check (
346
351
self ,
@@ -367,7 +372,8 @@ def from_params(cls, params_path):
367
372
}
368
373
369
374
use_tiktoken = loaded_params .get ("use_tiktoken" , False )
370
- return cls (transformer_args , model_type , use_tiktoken )
375
+ use_tokenizers = loaded_params .get ("use_tokenizers" , False )
376
+ return cls (transformer_args , model_type , use_tiktoken , use_tokenizers )
371
377
372
378
@classmethod
373
379
def from_table (cls , name : str ):
0 commit comments