Skip to content

Commit e44e608

Browse files
author
jaime-m-p
committed
Replace CODEPOINT_TYPE_* with codepoint_flags
1 parent 3b3963c commit e44e608

File tree

6 files changed

+5359
-2355
lines changed

6 files changed

+5359
-2355
lines changed

llama.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12575,16 +12575,16 @@ struct llm_tokenizer_wpm {
1257512575
// to lowercase, pad chinese characters, pad punctuation
1257612576
std::string new_str = "";
1257712577
for (uint32_t code : cpts_nfd) {
12578-
int type = unicode_cpt_type(code);
12579-
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
12578+
const codepoint_flags flags = unicode_cpt_flags(code);
12579+
if (flags.is_accent_mark || flags.is_control) {
1258012580
continue;
1258112581
}
1258212582
code = unicode_tolower(code);
12583-
if (type == CODEPOINT_TYPE_SEPARATOR) {
12583+
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
1258412584
code = ' ';
1258512585
}
1258612586
std::string s = unicode_cpt_to_utf8(code);
12587-
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
12587+
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
1258812588
new_str += " ";
1258912589
new_str += s;
1259012590
new_str += " ";

scripts/gen-unicode-data.py

Lines changed: 88 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,105 @@
11
import regex
2+
import ctypes
23

34

4-
def get_matches(regex_expr):
5-
regex_expr_compiled = regex.compile(regex_expr)
6-
unicode_ranges = []
7-
current_range = None
5+
class CoodepointFlags (ctypes.Structure):
6+
_fields_ = [ # see definition in unicode.h
7+
("is_undefined", ctypes.c_uint16, 1),
8+
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
9+
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
10+
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
11+
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
12+
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
13+
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
14+
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
15+
]
816

9-
for codepoint in range(0x110000):
10-
char = chr(codepoint)
11-
if regex_expr_compiled.match(char):
12-
if current_range is None:
13-
current_range = [codepoint, codepoint]
14-
else:
15-
current_range[1] = codepoint
16-
elif current_range is not None:
17-
unicode_ranges.append(tuple(current_range))
18-
current_range = None
17+
assert(ctypes.sizeof(CoodepointFlags) == 2)
1918

20-
if current_range is not None:
21-
unicode_ranges.append(tuple(current_range))
2219

23-
return unicode_ranges
20+
MAX_CODEPOINTS = 0x110000
2421

22+
regex_number = regex.compile(r'\p{N}')
23+
regex_letter = regex.compile(r'\p{L}')
24+
regex_separator = regex.compile(r'\p{Z}')
25+
regex_accent_mark = regex.compile(r'\p{M}')
26+
regex_punctuation = regex.compile(r'\p{P}')
27+
regex_symbol = regex.compile(r'\p{S}')
28+
regex_control = regex.compile(r'\p{C}')
29+
regex_whitespace = regex.compile(r'\s')
2530

26-
def print_cat(mode, cat, ranges):
27-
if mode == "range":
28-
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
29-
if mode == "map":
30-
print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
31-
for i, values in enumerate(ranges):
32-
end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
33-
values = ["0x%08X" % value for value in values]
34-
print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
35-
print("};") # noqa: NP100
36-
print("") # noqa: NP100
31+
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
32+
table_whitespace = []
33+
table_lowercase = []
34+
table_uppercase = []
3735

36+
for codepoint in range(MAX_CODEPOINTS):
37+
# convert codepoint to unicode character
38+
char = chr(codepoint)
3839

39-
print_cat("range", "number", get_matches(r'\p{N}'))
40-
print_cat("range", "letter", get_matches(r'\p{L}'))
41-
print_cat("range", "separator", get_matches(r'\p{Z}'))
42-
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
43-
print_cat("range", "punctuation", get_matches(r'\p{P}'))
44-
print_cat("range", "symbol", get_matches(r'\p{S}'))
45-
print_cat("range", "control", get_matches(r'\p{C}'))
46-
47-
print_cat("range", "whitespace", get_matches(r'\s'))
40+
# regex categories
41+
flags = codepoint_flags[codepoint]
42+
flags.is_number = bool(regex_number.match(char))
43+
flags.is_letter = bool(regex_letter.match(char))
44+
flags.is_separator = bool(regex_separator.match(char))
45+
flags.is_accent_mark = bool(regex_accent_mark.match(char))
46+
flags.is_punctuation = bool(regex_punctuation.match(char))
47+
flags.is_symbol = bool(regex_symbol.match(char))
48+
flags.is_control = bool(regex_control.match(char))
49+
flags.is_undefined = bytes(flags)[0] == 0
50+
assert(not flags.is_undefined)
4851

52+
# whitespaces
53+
if bool(regex_whitespace.match(char)):
54+
table_whitespace.append(codepoint)
4955

50-
map_lowercase = []
51-
map_uppercase = []
52-
for codepoint in range(0x110000):
53-
char = chr(codepoint)
56+
# lowercase conversion
5457
lower = ord(char.lower()[0])
55-
upper = ord(char.upper()[0])
5658
if codepoint != lower:
57-
map_lowercase.append((codepoint, lower))
59+
table_lowercase.append((codepoint, lower))
60+
61+
# uppercase conversion
62+
upper = ord(char.upper()[0])
5863
if codepoint != upper:
59-
map_uppercase.append((codepoint, upper))
60-
print_cat("map", "lowercase", map_lowercase)
61-
print_cat("map", "uppercase", map_uppercase)
64+
table_uppercase.append((codepoint, upper))
65+
66+
67+
ranges_flags = [(0, codepoint_flags[0])]
68+
for codepoint, flags in enumerate(codepoint_flags):
69+
if bytes(flags) != bytes(ranges_flags[-1][1]):
70+
ranges_flags.append((codepoint, flags))
71+
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
72+
73+
74+
# Generate 'unicode-data.cpp'
75+
76+
print("""\
77+
// generated with scripts/gen-unicode-data.py
78+
79+
#include "unicode-data.h"
80+
81+
#include <cstdint>
82+
#include <vector>
83+
#include <unordered_map>
84+
#include <unordered_set>
85+
""")
86+
87+
print("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
88+
for codepoint, flags in ranges_flags:
89+
flags = int.from_bytes(bytes(flags), "little")
90+
print("{0x%06X, 0x%04X}," % (codepoint, flags))
91+
print("};\n")
92+
93+
print("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
94+
print(", ".join("0x%06X" % cpt for cpt in table_whitespace))
95+
print("};\n")
6296

97+
print("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
98+
for tuple in table_lowercase:
99+
print("{0x%06X, 0x%06X}," % tuple)
100+
print("};\n")
63101

64-
# TODO: generate unicode_map_nfd
102+
print("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
103+
for tuple in table_uppercase:
104+
print("{0x%06X, 0x%06X}," % tuple)
105+
print("};\n")

0 commit comments

Comments
 (0)