@@ -107,7 +107,7 @@ def detokenize(self, ids: list[int], special: bool = False) -> str:
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
-        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8")
+        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'


 class Tokenizer:
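
For reference (not part of the patch), a minimal sketch of what `errors="replace"` changes here: invalid UTF-8 in the detokenizer buffer now decodes to U+FFFD instead of raising. The byte string below is only an illustrative truncated sequence, not data from the test.

```python
# Illustrative only: a truncated UTF-8 sequence like the detokenizer buffer may contain.
buf = b"abc\xe0\xa4"

print(str(buf, encoding="utf-8", errors="replace"))   # 'abc' followed by U+FFFD replacement

try:
    str(buf, encoding="utf-8")                        # strict decoding raises instead
except UnicodeDecodeError as exc:
    print("strict decode failed:", exc)
```
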
@@ -144,7 +144,7 @@ def encode(self, text: str) -> list[int]:
         return self.model.encode(text, add_special_tokens=True)

     def decode(self, ids: list[int]) -> str:
-        return self.model.decode(ids, skip_special_tokens=True)
+        return self.model.decode(ids, skip_special_tokens=False)


 class TokenizerLlamaCpp (Tokenizer):
@@ -160,7 +160,7 @@ def encode(self, text: str) -> list[int]:
         return self.model.tokenize(text, add_special=True, parse_special=True)

     def decode(self, ids: list[int]) -> str:
-        return self.model.detokenize(ids, special=False)
+        return self.model.detokenize(ids, special=True)


 def generator_custom_text() -> Iterator[str]:
@@ -232,6 +232,9 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '\xa0aC',         # deepseek
         '\u2029 \uA3E4',  # deepseek-llm
         "a ?",
+        'å',              # mpt
+        '\U000ac517',     # utf-8 encode error, falcon
+        '\U000522f4',     # utf-8 encode error, starcoder
     ]


@@ -265,7 +268,7 @@ def generator_apostrophe() -> Iterator[str]:


 def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
-    WHITESPACES = ["", " ", "  ", "   "]
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
     all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
     for token in all_tokens:
         for lstrip in WHITESPACES:
@@ -329,11 +332,9 @@ def generator_unicodes() -> Iterator[str]:
     def _valid(cpt):
         if cpt >= 0x30000:  # unassigned and supplementary
             return False
-        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
-            return False
         # if cpt == 0x2029:  # deepseek-llm
         #     return False
-        if unicodedata.category(chr(cpt)) == "Cn":
+        if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"):  # undefined, surrogates, private
             return False
         return True

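
As a quick illustration (assuming standard `unicodedata` behavior), the single category test now rejects the surrogate and private-use codepoints that the deleted range check handled explicitly, plus unassigned ones:

```python
import unicodedata

# 'A' (assigned), a surrogate, a private-use codepoint, an unassigned codepoint
for cpt in (0x0041, 0xD800, 0xE000, 0x0378):
    cat = unicodedata.category(chr(cpt))  # "Cs" = surrogate, "Co" = private use, "Cn" = unassigned
    print(hex(cpt), cat, cat not in ("Cn", "Cs", "Co"))
```
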
@@ -396,7 +397,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
         yield "".join(text)


-def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator: Iterator[str]):
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):

     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -406,12 +407,25 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
             return -1
         return min(len(ids1), len(ids2))

+    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+        if text1 == text2:  # equal to TokenizerGroundtruth?
+            return True
+        # equal to source text?
+        if tokenizer1.add_bos_token:  # remove BOS
+            if text2.startswith(tokenizer1.bos_token):
+                text2 = text2[len(tokenizer1.bos_token):]
+        if tokenizer1.add_eos_token:  # remove EOS
+            if text2.endswith(tokenizer1.eos_token):
+                text2 = text2[:-len(tokenizer1.eos_token)]
+        return text == text2
+
     t_encode1 = 0
     t_encode2 = 0
     t_decode1 = 0
     t_decode2 = 0
     t_start = time.perf_counter()
-    num_errors = 0
+    encode_errors = 0
+    decode_errors = 0

     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
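
A rough sketch of why `check_detokenizer` strips BOS/EOS before comparing against the source text: with `special=True` / `skip_special_tokens=False`, the decoded output can legitimately differ from the source only by rendered BOS/EOS markers. The `<s>` value below is hypothetical, for illustration only.

```python
# Hypothetical values, not taken from the patch.
bos_token = "<s>"
source    = "Hello world"
decoded   = "<s>Hello world"      # detokenizer output that renders the BOS token

text2 = decoded
if text2.startswith(bos_token):   # same stripping idea as check_detokenizer
    text2 = text2[len(bos_token):]
print(text2 == source)            # True -> not counted as a decode error
```
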
@@ -424,7 +438,7 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
         t2 = time.perf_counter()
         text1 = tokenizer1.decode(ids1)
         t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids2)
+        text2 = tokenizer2.decode(ids1)
         t4 = time.perf_counter()
         t_encode1 += t1 - t0
         t_encode2 += t2 - t1
@@ -436,16 +450,18 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.error(" Expected: " + str(ids1))
             logger.error("   Result: " + str(ids2))
-            num_errors += 1
-        if text1 != text2 and text != text2:
+            encode_errors += 1
+            logger.error(f" {encode_errors=}")
+        if not check_detokenizer(text, text1, text2):
             i = find_first_mismatch(text1, text2)
             text1 = list(text1[max(0, i - 2) : i + 5 + 1])
             text2 = list(text2[max(0, i - 2) : i + 5 + 1])
             logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
             logger.error("   Result: " + " ".join(hex(ord(x)) for x in text2))
-            num_errors += 1
-        if num_errors >= 10:
-            logger.error(f" EXIT: {num_errors=}")
+            decode_errors += 1
+            logger.error(f" {decode_errors=}")
+        if encode_errors >= 10 or decode_errors >= 10:
+            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
             break

@@ -504,6 +520,7 @@ def main(argv: list[str] = None):
     tokenizers = [
         "llama-spm",   # SPM
         "phi-3",       # SPM
+        "baichuan",    # SPM
         "bert-bge",    # WPM
         "jina-v2-en",  # WPM
         "llama-bpe",   # BPE