
Commit a5fa2fe

Author: jaime-m-p
Commit message: Style
1 parent edf375d commit a5fa2fe

File tree: 1 file changed (+29 −28 lines)


tests/test-tokenizer-random-bpe.py

Lines changed: 29 additions & 28 deletions
@@ -21,7 +21,7 @@ class LibLlama:
     DEFAULT_PATH_LLAMA_H = "./llama.h"
     DEFAULT_PATH_LIBLLAMA = "./build/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
 
-    def __init__(self, path_llama_h:str=None, path_libllama:str=None):
+    def __init__(self, path_llama_h: str = None, path_libllama: str = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
         (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
@@ -42,34 +42,35 @@ def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
         ffi.cdef(source, override=True)
         lib = ffi.dlopen(path_libllama)
         return (ffi, lib)
-
+
     def model_default_params(self, **kwargs):
         mparams = self.lib.llama_model_default_params()
         for k, v in kwargs.items():
             setattr(mparams, k, v)
         return mparams
-
+
     def context_default_params(self, **kwargs):
         cparams = self.lib.llama_context_default_params()
         for k, v in kwargs.items():
             setattr(cparams, k, v)
         return cparams
 
+
 class LibLlamaModel:
 
-    def __init__(self, libllama:LibLlama, path_model:str, mparams={}, cparams={}):
+    def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
         self.lib = libllama.lib
         self.ffi = libllama.ffi
         if type(mparams) == dict:
             mparams = libllama.model_default_params(**mparams)
         self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
         if not self.model:
-            raise RuntimeError("error: failed to load model '%s'"%path_model)
+            raise RuntimeError("error: failed to load model '%s'" % path_model)
         if type(cparams) == dict:
             cparams = libllama.context_default_params(**cparams)
         self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
         if not self.ctx:
-            raise RuntimeError("error: failed to create context for model '%s'"%path_model)
+            raise RuntimeError("error: failed to create context for model '%s'" % path_model)
         n_tokens_max = self.lib.llama_n_ctx(self.ctx)
         self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)
 
@@ -82,7 +83,7 @@ def free(self):
         self.model = None
         self.lib = None
 
-    def tokenize(self, text:str, n_tokens_max:int=0, add_special:bool=False, parse_special:bool=False) -> list[int]:
+    def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
         n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
         text = text.encode("utf-8")
         num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
@@ -91,14 +92,14 @@ def tokenize(self, text:str, n_tokens_max:int=0, add_special:bool=False, parse_special:bool=False) -> list[int]:
         return list(self.token_ids[0:num])
 
 
-def find_first_mismatch(ids1:list[int], ids2:list[int]):
+def find_first_mismatch(ids1: list[int], ids2: list[int]):
     for i, (a,b) in enumerate(zip(ids1, ids2)):
         if a != b:
             return i
     return -1 if len(ids1) == len(ids2) else i
 
 
-def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
+def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
 
     tests = [
         "",
@@ -153,7 +154,7 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
         '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
     ]
 
-    for text in tests+more_tests:
+    for text in tests + more_tests:
         ids1 = model.tokenize(text, parse_special=True)
         ids2 = tokenizer.encode(text)
         logger.info(repr(text))
@@ -164,22 +165,22 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
             raise Exception()
 
 
-def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):
 
-    WHITESPACES = list(" "*20 + "\n"*5 + "\r\n"*5 + "\t"*5)
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
     CHARS = list(set("""
         ABCDEFGHIJKLMNOPQRSTUVWXYZ
         abcdefghijklmnopqrstuvwxyz
         ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
         áéíóúàèìòùâêîôûäëïöü
         .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
     """))
-
+
     logger.info("Bruteforce random chars encodings ...")
     rand = random.Random()
     for m in range(iterations):
 
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
 
         text = []
@@ -188,29 +189,29 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
             k = rand.randint(1, 7)
             word = rand.choices(CHARS, k=k)
             space = rand.choice(WHITESPACES)
-            text.append("".join(word)+space)
+            text.append("".join(word) + space)
         text = "".join(text)
 
         ids1 = model.tokenize(text, parse_special=True)
         ids2 = tokenizer.encode(text)
         assert(ids1 == ids2)
 
 
-def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_vocab_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):
 
     logger.info("Building vocab char list ...")
     vocab_ids = list(tokenizer.vocab.values())
     vocab_text = tokenizer.decode(vocab_ids)
     vocab_chars = list(set(vocab_text))
     del vocab_ids, vocab_text
-
+
    logger.info("Bruteforce random text encodings ...")
     rand = random.Random()
     for m in range(iterations):
 
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
-
+
         text = rand.choices(vocab_chars, k=1024)
         text = "".join(text)
 
@@ -219,7 +220,7 @@ def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
         assert(ids1 == ids2)
 
 
-def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
 
     logger.info("Building token list ...")
     space_id = tokenizer.encode(" ")[0]
@@ -230,7 +231,7 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
     vocab_tokens = tokenizer.decode(vocab_ids)
     vocab_tokens = vocab_tokens.split(" ")
     del vocab_ids
-
+
     logger.info("Checking single token encodings ...")
     for token in vocab_tokens:
         ids1 = model.tokenize(token, parse_special=True)
@@ -241,15 +242,15 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
     rand = random.Random()
     for m in range(iterations):
 
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
-
+
         text = []
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 3)
             tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [ t.strip(" \n\r\t") for t in tokens ]
+            tokens = [t.strip(" \n\r\t") for t in tokens]
             sep = rand.choice(" \n\r\t")
             text.append("".join(tokens) + sep)
         text = "".join(text)
@@ -259,15 +260,15 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
         assert(ids1 == ids2)
 
 
-def test_random_bytes(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):
 
-    WHITESPACES = list(" "*20 + "\n"*5 + "\r\n"*5 + "\t"*5)
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
 
     logger.info("Bruteforce random bytes encodings ...")
     rand = random.Random()
     for m in range(iterations):
 
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
 
         text = []
@@ -302,6 +303,6 @@ def test_random_bytes(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
     test_random_chars(model, tokenizer, 10_000)
     test_random_vocab_chars(model, tokenizer, 10_000)
     test_random_vocab_tokens(model, tokenizer, 10_000)
-    #test_random_bytes(model, tokenizer, 10_000) # FAIL
+    # test_random_bytes(model, tokenizer, 10_000) # FAIL
 
     model.free()
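
The pattern applied throughout the diff is whitespace-only: spaces around `=` in annotated default arguments, spaces around binary operators and the `%` formatting operator, no padding inside list-comprehension brackets, and a space after the `#` of a commented-out call. A minimal before/after sketch of the same conventions, using hypothetical names not taken from the commit:

# Before (hypothetical example, not from the commit):
def scale(value:int, factor:int=2) -> int:
    return value*factor  #double by default

# After applying the same spacing conventions:
def scale(value: int, factor: int = 2) -> int:
    return value * factor  # double by default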
