Skip to content

Commit 9bc5d83

Browse files
author
jaime-m-p
committed
Minor + style
1 parent 1714e1a commit 9bc5d83

File tree

2 files changed

+34
-30
lines changed

2 files changed

+34
-30
lines changed

scripts/gen-unicode-data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class CoodepointFlags (ctypes.Structure):
1515
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
1616
]
1717

18-
assert(ctypes.sizeof(CoodepointFlags) == 2)
18+
19+
assert (ctypes.sizeof(CoodepointFlags) == 2)
1920

2021

2122
MAX_CODEPOINTS = 0x110000
@@ -49,7 +50,7 @@ class CoodepointFlags (ctypes.Structure):
4950
flags.is_symbol = bool(regex_symbol.match(char))
5051
flags.is_control = bool(regex_control.match(char))
5152
flags.is_undefined = bytes(flags)[0] == 0
52-
assert(not flags.is_undefined)
53+
assert (not flags.is_undefined)
5354

5455
# whitespaces
5556
if bool(regex_whitespace.match(char)):
@@ -72,15 +73,15 @@ class CoodepointFlags (ctypes.Structure):
7273

7374

7475
# group ranges with same flags
75-
ranges_flags = [(0, codepoint_flags[0])] # start, flags
76+
ranges_flags = [(0, codepoint_flags[0])] # start, flags
7677
for codepoint, flags in enumerate(codepoint_flags):
7778
if bytes(flags) != bytes(ranges_flags[-1][1]):
7879
ranges_flags.append((codepoint, flags))
7980
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
8081

8182

8283
# group ranges with same nfd
83-
ranges_nfd = [(0, 0, 0)] # start, last, nfd
84+
ranges_nfd = [(0, 0, 0)] # start, last, nfd
8485
for codepoint, norm in table_nfd:
8586
start = ranges_nfd[-1][0]
8687
if ranges_nfd[-1] != (start, codepoint - 1, norm):

tests/test-tokenizer-random.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
77
#
88

9+
import os
910
import time
1011
import logging
1112
import argparse
@@ -15,7 +16,7 @@
1516
from typing import Callable, Iterator
1617

1718
import cffi
18-
from transformers import AutoTokenizer, PreTrainedTokenizerBase
19+
from transformers import AutoTokenizer
1920

2021
logger = logging.getLogger("test-tokenizer-random-bpe")
2122

@@ -145,16 +146,16 @@ def generator_custom_text() -> Iterator[str]:
145146
def generator_custom_text_edge_cases() -> Iterator[str]:
146147
"""Edge cases found while debugging"""
147148
yield from [
148-
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
149-
'¼-a', # unicode_ranges_digit, 0x00BC
150-
'½-a', # unicode_ranges_digit, 0x00BD
151-
'¾-a', # unicode_ranges_digit, 0x00BE
152-
'a 〇b', # unicode_ranges_digit, 0x3007
153-
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
154-
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
155-
'Cửa Việt', # llama-3, ignore_merges = true
156-
'<s>a', # TODO: Phi-3 fail
157-
'a\na', # TODO: Bert fail
149+
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
150+
'¼-a', # unicode_ranges_digit, 0x00BC
151+
'½-a', # unicode_ranges_digit, 0x00BD
152+
'¾-a', # unicode_ranges_digit, 0x00BE
153+
'a 〇b', # unicode_ranges_digit, 0x3007
154+
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
155+
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
156+
'Cửa Việt', # llama-3, ignore_merges = true
157+
'<s>a', # TODO: Phi-3 fail
158+
'a\na', # TODO: Bert fail
158159
]
159160

160161

@@ -163,7 +164,7 @@ def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
163164
yield from vocab
164165

165166

166-
def generator_random_chars(iterations = 100) -> Iterator[str]:
167+
def generator_random_chars(iterations=100) -> Iterator[str]:
167168
"""Brute force random text with simple characters"""
168169

169170
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
@@ -188,7 +189,7 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
188189
yield "".join(text)
189190

190191

191-
def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator[str]:
192+
def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
192193
"""Brute force random text with vocab characters"""
193194

194195
vocab_chars = set()
@@ -203,7 +204,7 @@ def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator
203204
yield "".join(text)
204205

205206

206-
def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator[str]:
207+
def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
207208
"""Brute force random text from vocab words"""
208209

209210
vocab = [w.strip() for w in vocab]
@@ -222,7 +223,7 @@ def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator
222223
yield "".join(text)
223224

224225

225-
def generator_random_bytes(iterations = 100) -> Iterator[str]:
226+
def generator_random_bytes(iterations=100) -> Iterator[str]:
226227
"""Brute force random bytes"""
227228

228229
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
@@ -243,7 +244,7 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
243244
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
244245

245246
def find_first_mismatch(ids1: list[int], ids2: list[int]):
246-
for i, (a,b) in enumerate(zip(ids1, ids2)):
247+
for i, (a, b) in enumerate(zip(ids1, ids2)):
247248
if a != b:
248249
return i
249250
if len(ids1) == len(ids2):
@@ -259,33 +260,31 @@ def find_first_mismatch(ids1: list[int], ids2: list[int]):
259260
i = find_first_mismatch(ids1, ids2)
260261
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
261262
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
262-
text2 = tokenizer.decode(ids2, skip_special_tokens=True)
263-
#assert (text2 in text)
264-
logger.info(" Text: " + repr(text2))
265263
logger.info(" TokenIDs: " + str(ids1))
266264
logger.info(" Expected: " + str(ids2))
267265
raise Exception()
268266
t1 = time.perf_counter()
269267
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
270268

271269

272-
if __name__ == "__main__":
273-
270+
def main(argv: list[str] = None):
274271
parser = argparse.ArgumentParser()
275272
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
276273
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
277274
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
278-
args = parser.parse_args()
275+
args = parser.parse_args(argv)
279276

280277
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
281278

279+
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
282280
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
283-
def func_tokenize2(text:str):
281+
282+
def func_tokenize2(text: str):
284283
return tokenizer.encode(text, add_special_tokens=False)
285-
286-
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
284+
287285
parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
288-
def func_tokenize1(text:str):
286+
287+
def func_tokenize1(text: str):
289288
return model.tokenize(text, add_special=False, parse_special=parse_special)
290289

291290
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
@@ -298,3 +297,7 @@ def func_tokenize1(text:str):
298297
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
299298

300299
model.free()
300+
301+
302+
if __name__ == "__main__":
303+
main()

0 commit comments

Comments
 (0)