
Commit 4a28063

Author: jaime-m-p (committed)
Update brute force test:
 - Detokenize special tokens.
 - Replace errors with '\uFFFD' when detokenizing to 'utf-8'.
 - More edge cases.
 - Better detokenization results check.
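Note (illustration, not part of the commit): the '\uFFFD' replacement relies on standard Python decoding behavior. A minimal sketch of what errors="replace" does when the detokenizer buffer contains bytes that are not valid UTF-8:

# Hedged sketch of the decoding behavior assumed by the new detokenize():
# an invalid or truncated UTF-8 byte sequence decodes to U+FFFD instead of
# raising UnicodeDecodeError. The byte string below is made up for the example.
buf = b"abc\xf2\xac\x94"  # 4-byte UTF-8 sequence truncated at the end
print(str(buf, encoding="utf-8", errors="replace"))  # 'abc\ufffd'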
1 parent 95a0df5 commit 4a28063

File tree: 1 file changed (+32, −15 lines)


tests/test-tokenizer-random.py

Lines changed: 32 additions & 15 deletions
@@ -107,7 +107,7 @@ def detokenize(self, ids: list[int], special: bool = False) -> str:
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
-        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8")
+        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'


 class Tokenizer:
@@ -144,7 +144,7 @@ def encode(self, text: str) -> list[int]:
         return self.model.encode(text, add_special_tokens=True)

     def decode(self, ids: list[int]) -> str:
-        return self.model.decode(ids, skip_special_tokens=True)
+        return self.model.decode(ids, skip_special_tokens=False)


 class TokenizerLlamaCpp (Tokenizer):
@@ -160,7 +160,7 @@ def encode(self, text: str) -> list[int]:
         return self.model.tokenize(text, add_special=True, parse_special=True)

     def decode(self, ids: list[int]) -> str:
-        return self.model.detokenize(ids, special=False)
+        return self.model.detokenize(ids, special=True)


 def generator_custom_text() -> Iterator[str]:
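Note (illustration, not part of the commit): the two decode() changes above keep special tokens on both sides, so the ground-truth text and the llama.cpp text are compared on equal terms. A small sketch of the Hugging Face side, assuming the transformers package and a hypothetical local tokenizer path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./models/tokenizers/llama-bpe")  # hypothetical path
ids = tok.encode("Hello world", add_special_tokens=True)
print(tok.decode(ids, skip_special_tokens=True))   # 'Hello world'
print(tok.decode(ids, skip_special_tokens=False))  # BOS text kept, e.g. '<|begin_of_text|>Hello world'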
@@ -232,6 +232,9 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '\xa0aC',  # deepseek
         '\u2029 \uA3E4',  # deepseek-llm
         "a ?",
+        'å',  # mpt
+        '\U000ac517',  # utf-8 encode error, falcon
+        '\U000522f4',  # utf-8 encode error, starcoder
     ]


@@ -265,7 +268,7 @@ def generator_apostrophe() -> Iterator[str]:


 def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
-    WHITESPACES = ["", " ", "  ", "   "]
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
     all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
     for token in all_tokens:
         for lstrip in WHITESPACES:
@@ -329,11 +332,9 @@ def generator_unicodes() -> Iterator[str]:
     def _valid(cpt):
         if cpt >= 0x30000:  # unassigned and supplementary
             return False
-        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
-            return False
         # if cpt == 0x2029:  # deepseek-llm
         #     return False
-        if unicodedata.category(chr(cpt)) == "Cn":
+        if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"):  # undefined, surrogates, private
             return False
         return True

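Note (illustration, not part of the commit): the explicit surrogate range check becomes redundant once the category filter also rejects "Cs" and "Co". For reference, what those categories look like:

import unicodedata

print(unicodedata.category(chr(0x0378)))  # 'Cn' -> unassigned
print(unicodedata.category(chr(0xD800)))  # 'Cs' -> surrogate
print(unicodedata.category(chr(0xE000)))  # 'Co' -> private use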
@@ -396,7 +397,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
         yield "".join(text)


-def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator: Iterator[str]):
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):

     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -406,12 +407,25 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
                 return -1
         return min(len(ids1), len(ids2))

+    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+        if text1 == text2:  # equal to TokenizerGroundtruth?
+            return True
+        # equal to source text?
+        if tokenizer1.add_bos_token:  # remove BOS
+            if text2.startswith(tokenizer1.bos_token):
+                text2 = text2[len(tokenizer1.bos_token):]
+        if tokenizer1.add_eos_token:  # remove EOS
+            if text2.endswith(tokenizer1.eos_token):
+                text2 = text2[:-len(tokenizer1.eos_token)]
+        return text == text2
+
     t_encode1 = 0
     t_encode2 = 0
     t_decode1 = 0
     t_decode2 = 0
     t_start = time.perf_counter()
-    num_errors = 0
+    encode_errors = 0
+    decode_errors = 0

     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
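Note (illustration, not part of the commit): the new check_detokenizer() accepts a llama.cpp result that differs from the source text only by the BOS/EOS that the ground-truth tokenizer adds. A toy example of that fallback path (the token string and texts below are made up):

bos_token = "<s>"                # hypothetical BOS string
text = "Hello world"             # original input text
text2 = "<s>Hello world"         # llama.cpp detokenization with special=True

if text2.startswith(bos_token):  # same stripping as check_detokenizer()
    text2 = text2[len(bos_token):]
print(text == text2)             # True -> not counted as a decode error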
@@ -424,7 +438,7 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
         t2 = time.perf_counter()
         text1 = tokenizer1.decode(ids1)
         t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids2)
+        text2 = tokenizer2.decode(ids1)
         t4 = time.perf_counter()
         t_encode1 += t1 - t0
         t_encode2 += t2 - t1
@@ -436,16 +450,18 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.error(" Expected: " + str(ids1))
             logger.error(" Result:   " + str(ids2))
-            num_errors += 1
-        if text1 != text2 and text != text2:
+            encode_errors += 1
+            logger.error(f" {encode_errors=}")
+        if not check_detokenizer(text, text1, text2):
             i = find_first_mismatch(text1, text2)
             text1 = list(text1[max(0, i - 2) : i + 5 + 1])
             text2 = list(text2[max(0, i - 2) : i + 5 + 1])
             logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
             logger.error(" Result:   " + " ".join(hex(ord(x)) for x in text2))
-            num_errors += 1
-        if num_errors >= 10:
-            logger.error(f" EXIT: {num_errors=}")
+            decode_errors += 1
+            logger.error(f" {decode_errors=}")
+        if encode_errors >= 10 or decode_errors >= 10:
+            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
             break

@@ -504,6 +520,7 @@ def main(argv: list[str] = None):
     tokenizers = [
         "llama-spm",   # SPM
         "phi-3",       # SPM
+        "baichuan",    # SPM
         "bert-bge",    # WPM
         "jina-v2-en",  # WPM
         "llama-bpe",   # BPE
