@@ -235,6 +235,8 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
235
235
'å' , # mpt
236
236
'\U000ac517 ' , # utf-8 encode error, falcon
237
237
'\U000522f4 ' , # utf-8 encode error, starcoder
238
+ "<s><s><unk><s>a<s>b<s>c<unk>d<unk></s>" ,
239
+ "<s> <s> <unk><s>a<s>b<s>c<unk>d<unk></s>" ,
238
240
]
239
241
240
242
@@ -334,7 +336,7 @@ def _valid(cpt):
334
336
return False
335
337
# if cpt == 0x2029: # deepseek-llm
336
338
# return False
337
- if unicodedata .category (chr (cpt )) in ( "Cn" , "Cs" , "Co" ): # undefined, surrogates, private
339
+ if unicodedata .category (chr (cpt )) in ("Cn" , "Cs" , "Co" ): # undefined, surrogates, private
338
340
return False
339
341
return True
340
342
@@ -426,6 +428,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool:
426
428
t_start = time .perf_counter ()
427
429
encode_errors = 0
428
430
decode_errors = 0
431
+ MAX_ERRORS = 10
429
432
430
433
logger .info ("%s: %s" % (generator .__name__ , "ini" ))
431
434
for text in generator :
@@ -444,23 +447,23 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool:
444
447
t_encode2 += t2 - t1
445
448
t_decode1 += t3 - t2
446
449
t_decode2 += t4 - t3
447
- if ids1 != ids2 :
450
+ if encode_errors < MAX_ERRORS and ids1 != ids2 :
448
451
i = find_first_mismatch (ids1 , ids2 )
449
452
ids1 = list (ids1 )[max (0 , i - 2 ) : i + 5 + 1 ]
450
453
ids2 = list (ids2 )[max (0 , i - 2 ) : i + 5 + 1 ]
451
454
logger .error (" Expected: " + str (ids1 ))
452
455
logger .error (" Result: " + str (ids2 ))
453
456
encode_errors += 1
454
457
logger .error (f" { encode_errors = } " )
455
- if not check_detokenizer (text , text1 , text2 ):
458
+ if decode_errors < MAX_ERRORS and not check_detokenizer (text , text1 , text2 ):
456
459
i = find_first_mismatch (text1 , text2 )
457
460
text1 = list (text1 [max (0 , i - 2 ) : i + 5 + 1 ])
458
461
text2 = list (text2 [max (0 , i - 2 ) : i + 5 + 1 ])
459
462
logger .error (" Expected: " + " " .join (hex (ord (x )) for x in text1 ))
460
463
logger .error (" Result: " + " " .join (hex (ord (x )) for x in text2 ))
461
464
decode_errors += 1
462
465
logger .error (f" { decode_errors = } " )
463
- if encode_errors >= 10 or decode_errors >= 10 :
466
+ if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS :
464
467
logger .error (f" EXIT: { encode_errors = } { decode_errors = } " )
465
468
# raise Exception()
466
469
break
0 commit comments