Skip to content

Commit 68220fe

Browse files
author
jaime-m-p
committed
Update bruteforce test
1 parent 107923c commit 68220fe

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test-tokenizer-random.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,8 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
235235
'å', # mpt
236236
'\U000ac517', # utf-8 encode error, falcon
237237
'\U000522f4', # utf-8 encode error, starcoder
238+
"<s><s><unk><s>a<s>b<s>c<unk>d<unk></s>",
239+
"<s> <s> <unk><s>a<s>b<s>c<unk>d<unk></s>",
238240
]
239241

240242

@@ -334,7 +336,7 @@ def _valid(cpt):
334336
return False
335337
# if cpt == 0x2029: # deepseek-llm
336338
# return False
337-
if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ): # undefined, surrogates, private
339+
if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"): # undefined, surrogates, private
338340
return False
339341
return True
340342

@@ -426,6 +428,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool:
426428
t_start = time.perf_counter()
427429
encode_errors = 0
428430
decode_errors = 0
431+
MAX_ERRORS = 10
429432

430433
logger.info("%s: %s" % (generator.__name__, "ini"))
431434
for text in generator:
@@ -444,23 +447,23 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool:
444447
t_encode2 += t2 - t1
445448
t_decode1 += t3 - t2
446449
t_decode2 += t4 - t3
447-
if ids1 != ids2:
450+
if encode_errors < MAX_ERRORS and ids1 != ids2:
448451
i = find_first_mismatch(ids1, ids2)
449452
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
450453
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
451454
logger.error(" Expected: " + str(ids1))
452455
logger.error(" Result: " + str(ids2))
453456
encode_errors += 1
454457
logger.error(f" {encode_errors=}")
455-
if not check_detokenizer(text, text1, text2):
458+
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
456459
i = find_first_mismatch(text1, text2)
457460
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
458461
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
459462
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
460463
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
461464
decode_errors += 1
462465
logger.error(f" {decode_errors=}")
463-
if encode_errors >= 10 or decode_errors >= 10:
466+
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
464467
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
465468
# raise Exception()
466469
break

0 commit comments

Comments
 (0)