
Commit 77cbb79

Author: jaime-m-p
Refactor random tokenizer test
1 parent 70ca1fe commit 77cbb79
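The refactor below turns each ad-hoc test_* function into a plain text generator and funnels every case through a single comparison routine. At its core, the check that routine performs is that libllama and the HuggingFace reference tokenizer agree on the token IDs for a given text. A minimal sketch of that comparison, reusing the model.tokenize / tokenizer.encode calls visible in the diff (variable names here are illustrative only):

    # Sketch: `model` is a LibLlamaModel wrapper around libllama, `tokenizer` an
    # AutoTokenizer loaded from the matching vocab; both must produce identical IDs.
    ids1 = model.tokenize(text, add_special=False, parse_special=False)  # llama.cpp
    ids2 = tokenizer.encode(text, add_special_tokens=False)              # HF reference
    assert ids1 == ids2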

File tree

1 file changed: +65, -81 lines


tests/test-tokenizer-random-bpe.py renamed to tests/test-tokenizer-random.py

Lines changed: 65 additions & 81 deletions
@@ -1,15 +1,19 @@
-# tests with BPE tokenizer
+# Test libllama tokenizer == AutoTokenizer.
+# Brute force random tokens/text generation.
 #
-# sample usage:
+# Sample usage:
 #
-# python3 tests/test-tokenizer-0-bpe.py ./models/ggml-vocab-llama-bpe.gguf ~/Data/huggingface/Meta-Llama-3-8B-Instruct/
+# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
 #

+import time
 import logging
 import argparse
 import subprocess
 import random

+from typing import Iterator
+
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

@@ -30,7 +34,7 @@ def __init__(self, path_llama_h: str = None, path_libllama: str = None):
     def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
         cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
         res = subprocess.run(cmd, stdout=subprocess.PIPE)
-        assert(res.returncode == 0)
+        assert (res.returncode == 0)
         source = res.stdout.decode()
         ffi = cffi.FFI()
         if True: # workarounds for pycparser
@@ -61,12 +65,12 @@ class LibLlamaModel:
     def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
         self.lib = libllama.lib
         self.ffi = libllama.ffi
-        if type(mparams) == dict:
+        if isinstance(mparams, dict):
             mparams = libllama.model_default_params(**mparams)
         self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
         if not self.model:
             raise RuntimeError("error: failed to load model '%s'" % path_model)
-        if type(cparams) == dict:
+        if isinstance(cparams, dict):
             cparams = libllama.context_default_params(**cparams)
         self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
         if not self.ctx:
@@ -92,18 +96,9 @@ def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False,
         return list(self.token_ids[0:num])


-def find_first_mismatch(ids1: list[int], ids2: list[int]):
-    for i, (a,b) in enumerate(zip(ids1, ids2)):
-        if a != b:
-            return i
-    if len(ids1) == len(ids2):
-        return -1
-    return min(len(ids1), len(ids2))
-
-
-def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
-
-    tests = [
+def generator_custom_text() -> Iterator[str]:
+    """General tests"""
+    yield from [
         "",
         " ",
         "  ",
@@ -146,7 +141,10 @@ def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
         "333333333",
     ]

-    more_tests = [
+
+def generator_custom_text_edge_cases() -> Iterator[str]:
+    """Edge cases found while debugging"""
+    yield from [
         '\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
         '¼-a', # unicode_ranges_digit, 0x00BC
         '½-a', # unicode_ranges_digit, 0x00BD
@@ -157,18 +155,9 @@ def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
         '<s>a' # TODO: Phi-3 fail
     ]

-    for text in tests + more_tests:
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
-        logger.info(repr(text))
-        if ids1 != ids2:
-            logger.info(" TokenIDs: " + str(list(ids1)))
-            logger.info(" Expected: " + str(list(ids2)))
-            logger.info(" Index: %d" % find_first_mismatch(ids1, ids2))
-            raise Exception()
-

-def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
+def generator_random_chars(iterations = 100) -> Iterator[str]:
+    """Brute force random text with simple characters"""

     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
     CHARS = list(set("""
@@ -179,75 +168,51 @@ def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
         .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
     """))

-    logger.info("Bruteforce random chars encodings ...")
     rand = random.Random()
     for m in range(iterations):
-
-        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
-
         text = []
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 7)
             word = rand.choices(CHARS, k=k)
             space = rand.choice(WHITESPACES)
             text.append("".join(word) + space)
-        text = "".join(text)
+        yield "".join(text)

-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
-        assert(ids1 == ids2)

+def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+    """Brute force random text with vocab characters"""

-def test_random_vocab_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
-
-    logger.info("Building vocab char list ...")
     vocab_ids = list(tokenizer.vocab.values())
-    vocab_text = tokenizer.decode(vocab_ids)
+    vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
     vocab_chars = list(set(vocab_text))
     del vocab_ids, vocab_text

-    logger.info("Bruteforce random text encodings ...")
     rand = random.Random()
     for m in range(iterations):
-
-        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
-
         text = rand.choices(vocab_chars, k=1024)
-        text = "".join(text)
-
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
-        assert(ids1 == ids2)
+        yield "".join(text)


-def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
+def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+    """Brute force random text from vocab tokens"""

-    logger.info("Building token list ...")
-    space_id = tokenizer.encode(" ")[0]
+    space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
     vocab_ids = list(tokenizer.vocab.values())
     vocab_ids = list(sorted(vocab_ids + vocab_ids))
     for i in range(1, len(vocab_ids), 2):
         vocab_ids[i] = space_id
-    vocab_tokens = tokenizer.decode(vocab_ids)
+    vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
     vocab_tokens = vocab_tokens.split(" ")
     del vocab_ids

-    logger.info("Checking single token encodings ...")
-    for token in vocab_tokens:
-        ids1 = model.tokenize(token, parse_special=True)
-        ids2 = tokenizer.encode(token)
-        assert(ids1 == ids2)
+    yield from vocab_tokens

-    logger.info("Bruteforce random text encodings ...")
     rand = random.Random()
     for m in range(iterations):
-
-        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
-
         text = []
         num_words = rand.randint(300, 400)
         for i in range(num_words):
@@ -256,36 +221,54 @@ def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenize
             tokens = [t.strip(" \n\r\t") for t in tokens]
             sep = rand.choice(" \n\r\t")
             text.append("".join(tokens) + sep)
-        text = "".join(text)
-
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
-        assert(ids1 == ids2)
+        yield "".join(text)


-def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
+def generator_random_bytes(iterations = 100) -> Iterator[str]:
+    """Brute force random bytes"""

     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)

-    logger.info("Bruteforce random bytes encodings ...")
     rand = random.Random()
     for m in range(iterations):
-
-        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
-
         text = []
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 8)
             word = [chr(r) for r in rand.randbytes(k) if r]
             word.append(rand.choice(WHITESPACES))
             text.append("".join(word))
-        text = "".join(text)
+        yield "".join(text)
+

+def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
+
+    def find_first_mismatch(ids1: list[int], ids2: list[int]):
+        for i, (a,b) in enumerate(zip(ids1, ids2)):
+            if a != b:
+                return i
+        if len(ids1) == len(ids2):
+            return -1
+        return min(len(ids1), len(ids2))
+
+    t0 = time.perf_counter()
+    logger.info("%s: %s" % (generator.__name__, "ini"))
+    for text in generator:
         ids1 = model.tokenize(text, add_special=False, parse_special=False)
         ids2 = tokenizer.encode(text, add_special_tokens=False)
-        assert(ids1 == ids2)
+        if ids1 != ids2:
+            i = find_first_mismatch(ids1, ids2)
+            ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
+            ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
+            text2 = tokenizer.decode(ids2, skip_special_tokens=True)
+            assert (text2 in text)
+            logger.info(" Text: " + repr(text2))
+            logger.info(" TokenIDs: " + str(ids1))
+            logger.info(" Expected: " + str(ids2))
+            raise Exception()
+    t1 = time.perf_counter()
+    logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))


 if __name__ == "__main__":
@@ -302,10 +285,11 @@ def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,

     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)

-    test_custom_texts(model, tokenizer)
-    test_random_chars(model, tokenizer, 10_000)
-    test_random_vocab_chars(model, tokenizer, 10_000)
-    test_random_vocab_tokens(model, tokenizer, 10_000)
-    # test_random_bytes(model, tokenizer, 10_000) # FAIL
+    test_compare_tokenizer(model, tokenizer, generator_custom_text())
+    test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
+    test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
+    test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
+    test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
+    # test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL

     model.free()
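With the generators split out from the comparison loop, adding another brute-force case only requires writing a new Iterator[str] and handing it to test_compare_tokenizer. A hypothetical example (not part of this commit; generator_ascii_words is an invented name, assumed to be added inside tests/test-tokenizer-random.py where random and Iterator are already imported):

    def generator_ascii_words(iterations = 100) -> Iterator[str]:
        """Hypothetical extra case: short lowercase ASCII words separated by spaces."""
        rand = random.Random()
        for m in range(iterations):
            rand.seed(m)  # deterministic per iteration, like the existing generators
            num_words = rand.randint(300, 400)
            words = ["".join(rand.choices("abcdefghijklmnopqrstuvwxyz", k=rand.randint(1, 8)))
                     for _ in range(num_words)]
            yield " ".join(words)

    # Wired into the __main__ block alongside the existing calls:
    #   test_compare_tokenizer(model, tokenizer, generator_ascii_words(10_000))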
