
Commit bb205ee

Author: jaime-m-p

Update and bugfix brute force random test

1 parent: e44e608

File tree: 1 file changed, 40 additions and 35 deletions

tests/test-tokenizer-random.py
@@ -1,5 +1,5 @@
 # Test libllama tokenizer == AutoTokenizer.
-# Brute force random tokens/text generation.
+# Brute force random words/text generation.
 #
 # Sample usage:
 #
@@ -12,7 +12,7 @@
 import subprocess
 import random
 
-from typing import Iterator
+from typing import Callable, Iterator
 
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -152,10 +152,17 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         'a 〇b', # unicode_ranges_digit, 0x3007
         'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
         '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
-        '<s>a' # TODO: Phi-3 fail
+        'Cửa Việt', # llama-3, ignore_merges = true
+        '<s>a', # TODO: Phi-3 fail
+        'a\na', # TODO: Bert fail
     ]
 
 
+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
 def generator_random_chars(iterations = 100) -> Iterator[str]:
     """Brute force random text with simple characters"""
 
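Note: the new generator_vocab_words simply replays every detokenized vocab entry. The vocab list it consumes is built once in the main block further below; a minimal sketch of that wiring, assuming a local tokenizer directory (the path is a placeholder):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")  # placeholder path
    # batch_decode with one id per entry detokenizes each token id back to its
    # surface string, yielding one vocab word per id.
    vocab = tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)
    words = list(generator_vocab_words(vocab))  # every vocab word, unchanged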
@@ -181,13 +188,13 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
         yield "".join(text)
 
 
-def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator[str]:
     """Brute force random text with vocab characters"""
 
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_chars = list(set(vocab_text))
-    del vocab_ids, vocab_text
+    vocab_chars = set()
+    for word in vocab:
+        vocab_chars.update(word)
+    vocab_chars = list(vocab_chars)
 
     rand = random.Random()
     for m in range(iterations):
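Note: both the old and the new code build vocab_chars from a set, whose iteration order varies across runs (Python randomizes string hashing per process), so the seeded rand.choices(...) below is reproducible within a run but not necessarily across runs. A minimal deterministic variant, offered as a suggestion rather than as part of the commit:

    # Sorting pins the population order, so rand.seed(m) reproduces the
    # exact same random text in every run.
    vocab_chars = sorted({ch for word in vocab for ch in word})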
@@ -196,19 +203,11 @@ def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations
         yield "".join(text)
 
 
-def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
-    """Brute force random text from vocab tokens"""
-
-    space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_ids = list(sorted(vocab_ids + vocab_ids))
-    for i in range(1, len(vocab_ids), 2):
-        vocab_ids[i] = space_id
-    vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_tokens = vocab_tokens.split(" ")
-    del vocab_ids
+def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator[str]:
+    """Brute force random text from vocab words"""
 
-    yield from vocab_tokens
+    vocab = [w.strip() for w in vocab]
+    yield from vocab
 
     rand = random.Random()
     for m in range(iterations):
@@ -217,10 +216,9 @@ def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 3)
-            tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [t.strip(" \n\r\t") for t in tokens]
+            words = rand.choices(vocab, k=k)
             sep = rand.choice(" \n\r\t")
-            text.append("".join(tokens) + sep)
+            text.append("".join(words) + sep)
         yield "".join(text)
 
 
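Note: the renamed generator_random_vocab_words strips the whole vocab once up front instead of stripping every random draw, then replays the stripped vocab before emitting random concatenations of one to three words joined by a random whitespace separator. A small illustration with a toy vocab (toy data, not from the commit):

    gen = generator_random_vocab_words(["foo ", " bar"], iterations=1)
    assert next(gen) == "foo"   # the stripped vocab is replayed first
    assert next(gen) == "bar"
    text = next(gen)            # then one long text of 300-400 word groups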
@@ -242,7 +240,7 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
         yield "".join(text)
 
 
-def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
+def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a,b) in enumerate(zip(ids1, ids2)):
@@ -255,8 +253,8 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
     t0 = time.perf_counter()
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
+        ids1 = func_tokenize1(text)
+        ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
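Note: with the tokenizers injected as plain callables, the comparison loop is decoupled from LibLlamaModel and PreTrainedTokenizerBase. A minimal sketch with two stand-in functions (both illustrative, not part of the commit):

    def tok_a(text: str) -> list[int]:
        return [ord(c) for c in text]          # stand-in for libllama tokenization

    def tok_b(text: str) -> list[int]:
        return [ord(c) for c in text.lower()]  # stand-in for the HF tokenizer

    # Mismatches are logged with a small window around the first differing id.
    test_compare_tokenizer(tok_a, tok_b, generator_custom_text())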
@@ -281,15 +279,22 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
 
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
 
-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
-
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
-
-    test_compare_tokenizer(model, tokenizer, generator_custom_text())
-    test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
-    test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
-    # test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
+    def func_tokenize2(text: str):
+        return tokenizer.encode(text, add_special_tokens=False)
+
+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+    parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
+    def func_tokenize1(text: str):
+        return model.tokenize(text, add_special=False, parse_special=parse_special)
+
+    vocab = tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 10_000))
+    # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
 
     model.free()
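Note: the new parse_special probe enables special-token parsing on the libllama side only when the HF tokenizer encodes each of its special tokens to a single id, i.e. when the reference tokenizer itself treats them as atomic (this matters for edge cases like '<s>a' above). Restated standalone, assuming tokenizer is the loaded AutoTokenizer:

    # True iff every special token encodes to exactly one id.
    parse_special = all(
        len(tokenizer.encode(t, add_special_tokens=False)) == 1
        for t in tokenizer.all_special_tokens
    )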
