@@ -6,6 +6,7 @@
 # python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
 #
 
+import os
 import time
 import logging
 import argparse
@@ -15,7 +16,7 @@
 from typing import Callable, Iterator
 
 import cffi
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer
 
 logger = logging.getLogger("test-tokenizer-random-bpe")
 
@@ -145,16 +146,16 @@ def generator_custom_text() -> Iterator[str]:
 def generator_custom_text_edge_cases() -> Iterator[str]:
     """Edge cases found while debugging"""
     yield from [
-        '\x1f-a',  # unicode_ranges_control, {0x00001C, 0x00001F}
-        '¼-a',  # unicode_ranges_digit, 0x00BC
-        '½-a',  # unicode_ranges_digit, 0x00BD
-        '¾-a',  # unicode_ranges_digit, 0x00BE
-        'a 〇b',  # unicode_ranges_digit, 0x3007
-        'Ⅵ-a',  # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
-        '\uFEFF//',  # unicode_ranges_control, 0xFEFF (BOM)
-        'Cửa Việt',  # llama-3, ignore_merges = true
-        '<s>a',  # TODO: Phi-3 fail
-        'a\na',  # TODO: Bert fail
+        '\x1f-a',  # unicode_ranges_control, {0x00001C, 0x00001F}
+        '¼-a',  # unicode_ranges_digit, 0x00BC
+        '½-a',  # unicode_ranges_digit, 0x00BD
+        '¾-a',  # unicode_ranges_digit, 0x00BE
+        'a 〇b',  # unicode_ranges_digit, 0x3007
+        'Ⅵ-a',  # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
+        '\uFEFF//',  # unicode_ranges_control, 0xFEFF (BOM)
+        'Cửa Việt',  # llama-3, ignore_merges = true
+        '<s>a',  # TODO: Phi-3 fail
+        'a\na',  # TODO: Bert fail
     ]
 
 
@@ -163,7 +164,7 @@ def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
     yield from vocab
 
 
-def generator_random_chars(iterations = 100) -> Iterator[str]:
+def generator_random_chars(iterations=100) -> Iterator[str]:
     """Brute force random text with simple characters"""
 
     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
@@ -188,7 +189,7 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
         yield "".join(text)
 
 
-def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator[str]:
+def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
     """Brute force random text with vocab characters"""
 
     vocab_chars = set()
@@ -203,7 +204,7 @@ def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator
         yield "".join(text)
 
 
-def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator[str]:
+def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
     """Brute force random text from vocab words"""
 
     vocab = [w.strip() for w in vocab]
@@ -222,7 +223,7 @@ def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator
         yield "".join(text)
 
 
-def generator_random_bytes(iterations = 100) -> Iterator[str]:
+def generator_random_bytes(iterations=100) -> Iterator[str]:
     """Brute force random bytes"""
 
    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
@@ -243,7 +244,7 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
 def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
-        for i, (a,b) in enumerate(zip(ids1, ids2)):
+        for i, (a, b) in enumerate(zip(ids1, ids2)):
             if a != b:
                 return i
         if len(ids1) == len(ids2):
@@ -259,33 +260,31 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
-            text2 = tokenizer.decode(ids2, skip_special_tokens=True)
-            #assert (text2 in text)
-            logger.info(" Text: " + repr(text2))
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
     t1 = time.perf_counter()
     logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
 
 
-if __name__ == "__main__":
-
+def main(argv: list[str] = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
 
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
 
+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
-    def func_tokenize2(text:str):
+
+    def func_tokenize2(text: str):
         return tokenizer.encode(text, add_special_tokens=False)
-
-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+
     parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
-    def func_tokenize1(text:str):
+
+    def func_tokenize1(text: str):
         return model.tokenize(text, add_special=False, parse_special=parse_special)
 
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
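In the refactored main() above, parse_special is enabled for the llama.cpp tokenizer only when the Hugging Face tokenizer encodes every special token to a single id, so both sides treat inputs such as '<s>a' from the edge-case generator the same way. A rough, self-contained illustration of that check; special_tokens, fake_encode, and the ids below are made up for the example and merely stand in for tokenizer.all_special_tokens and tokenizer.encode:

# Sketch only: a stand-in encoder that maps each special token to exactly one id.
special_tokens = ["<s>", "</s>"]

def fake_encode(text: str) -> list[int]:
    if text in special_tokens:
        return [1]                     # one id per special token
    return [ord(c) for c in text]      # placeholder ids for ordinary text

# Mirrors: all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
parse_special = all(len(fake_encode(t)) == 1 for t in special_tokens)
print(parse_special)  # True -> special-token parsing is enabled on the llama.cpp side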
@@ -298,3 +297,7 @@ def func_tokenize1(text:str):
     # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
 
     model.free()
+
+
+if __name__ == "__main__":
+    main()
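Because parse_args(argv) now accepts an explicit argument list, main() can also be driven from another Python process rather than only from the shell. A minimal sketch of that usage, assuming the script is loaded through importlib (the hyphenated file name rules out a plain import) and reusing the paths from the header comment:

import importlib.util

# Load tests/test-tokenizer-random.py under a hypothetical module name.
spec = importlib.util.spec_from_file_location("test_tokenizer_random", "tests/test-tokenizer-random.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Equivalent to:
#   python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe --verbose
module.main(["./models/ggml-vocab-llama-bpe.gguf", "./models/tokenizers/llama-bpe", "--verbose"])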