@@ -272,7 +272,9 @@ class TransformerArgs:
272
272
norm_eps : float = 1e-5
273
273
multiple_of : int = 256
274
274
ffn_dim_multiplier : Optional [int ] = None
275
+ # Select the desired tokenizer via the flags below; when neither flag is set, sentencepiece is used.
275
276
use_tiktoken : bool = False
277
+ use_tokenizers : bool = False
276
278
max_seq_length : int = 8192
277
279
rope_scaling : Optional [Dict [str , Any ]] = None
278
280
# For pipeline parallel
@@ -329,12 +331,14 @@ class ModelArgs:
329
331
model_type : ModelType
330
332
transformer_args : Dict [str , Dict [str , Any ]]
331
333
use_tiktoken : bool
334
+ use_tokenizers : bool
332
335
333
336
def __init__ (
334
337
self ,
335
338
transformer_args : Dict [str , Dict [str , Any ]],
336
339
model_type : ModelType = ModelType .TextOnly ,
337
340
use_tiktoken : bool = False ,
341
+ use_tokenizers : bool = False ,
338
342
) -> None :
339
343
self ._sanity_check (transformer_args , model_type )
340
344
@@ -343,6 +347,7 @@ def __init__(
343
347
344
348
# Model-level attributes
345
349
self .use_tiktoken = use_tiktoken
350
+ self .use_tokenizers = use_tokenizers
346
351
347
352
def _sanity_check (
348
353
self ,
@@ -369,7 +374,8 @@ def from_params(cls, params_path):
369
374
}
370
375
371
376
use_tiktoken = loaded_params .get ("use_tiktoken" , False )
372
- return cls (transformer_args , model_type , use_tiktoken )
377
+ use_tokenizers = loaded_params .get ("use_tokenizers" , False )
378
+ return cls (transformer_args , model_type , use_tiktoken , use_tokenizers )
373
379
374
380
@classmethod
375
381
def from_table (cls , name : str ):
0 commit comments