Skip to content

Commit 6c80b3c

Browse files
committed
Added needed functionality, testing remains
1 parent 0c991be commit 6c80b3c

File tree

5 files changed

+175
-173
lines changed

5 files changed

+175
-173
lines changed

llama.cpp

Lines changed: 21 additions & 32 deletions
Large diffs are not rendered by default.

unicode-data.cpp

Lines changed: 5 additions & 1 deletion
Large diffs are not rendered by default.

unicode-data.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <cstdint>
44
#include <map>
5+
#include <set>
6+
#include <string>
57
#include <utility>
68
#include <vector>
79

@@ -14,4 +16,5 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
1416
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
1517
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
1618
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
17-
extern const std::map<std::string, std::wstring> unicode_regex_to_wregex;
19+
extern const std::map<std::string, std::wstring> unicode_regex_equivalent_wregex;
20+
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;

unicode.cpp

Lines changed: 144 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -228,133 +228,138 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
228228
return bpe_encoded_words;
229229
}
230230

231-
static std::vector<std::string> unicode_custom_preprocess(const std::string & text) {
232-
std::vector<std::string> bpe_words;
231+
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
232+
std::vector<size_t> bpe_offsets; // stroe the offset of each word
233+
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
234+
size_t start = 0;
233235

234-
std::string token = "";
235-
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
236-
bool collecting_numeric = false;
237-
bool collecting_letter = false;
238-
bool collecting_special = false;
239-
bool collecting_whitespace_lookahead = false;
240-
bool collecting = false;
241-
242-
std::vector<std::string> text_utf;
243-
text_utf.reserve(text.size());
244-
bpe_words.reserve(text.size());
245-
246-
const auto cpts = unicode_cpts_from_utf8(text);
247-
for (size_t i = 0; i < cpts.size(); ++i)
248-
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
249-
250-
for (int i = 0; i < (int)text_utf.size(); i++) {
251-
const std::string & utf_char = text_utf[i];
252-
bool split_condition = false;
253-
int bytes_remain = text_utf.size() - i;
254-
// forward backward lookups
255-
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
256-
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
257-
258-
// handling contractions
259-
if (!split_condition && bytes_remain >= 2) {
260-
// 's|'t|'m|'d
261-
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
262-
split_condition = true;
263-
}
264-
if (split_condition) {
265-
if (token.size()) {
266-
bpe_words.emplace_back(token); // push previous content as token
236+
for(auto offset : offsets) {
237+
const std::string text = unicode_wstring_to_utf8(std::wstring(wtext, start, offset));
238+
239+
std::string token = "";
240+
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
241+
bool collecting_numeric = false;
242+
bool collecting_letter = false;
243+
bool collecting_special = false;
244+
bool collecting_whitespace_lookahead = false;
245+
bool collecting = false;
246+
247+
std::vector<std::string> text_utf;
248+
text_utf.reserve(text.size());
249+
250+
const auto cpts = unicode_cpts_from_utf8(text);
251+
for (size_t i = 0; i < cpts.size(); ++i)
252+
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
253+
254+
for (int i = 0; i < (int)text_utf.size(); i++) {
255+
const std::string & utf_char = text_utf[i];
256+
bool split_condition = false;
257+
int bytes_remain = text_utf.size() - i;
258+
// forward backward lookups
259+
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
260+
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
261+
262+
// handling contractions
263+
if (!split_condition && bytes_remain >= 2) {
264+
// 's|'t|'m|'d
265+
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
266+
split_condition = true;
267+
}
268+
if (split_condition) {
269+
if (token.size()) {
270+
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
271+
}
272+
token = utf_char + utf_char_next;
273+
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
274+
token = "";
275+
i++;
276+
continue;
267277
}
268-
token = utf_char + utf_char_next;
269-
bpe_words.emplace_back(token);
270-
token = "";
271-
i++;
272-
continue;
273-
}
274-
}
275-
if (!split_condition && bytes_remain >= 3) {
276-
// 're|'ve|'ll
277-
if (utf_char == "\'" && (
278-
(utf_char_next == "r" && utf_char_next_next == "e") ||
279-
(utf_char_next == "v" && utf_char_next_next == "e") ||
280-
(utf_char_next == "l" && utf_char_next_next == "l"))
281-
) {
282-
split_condition = true;
283278
}
284-
if (split_condition) {
285-
// current token + next token can be defined
286-
if (token.size()) {
287-
bpe_words.emplace_back(token); // push previous content as token
279+
if (!split_condition && bytes_remain >= 3) {
280+
// 're|'ve|'ll
281+
if (utf_char == "\'" && (
282+
(utf_char_next == "r" && utf_char_next_next == "e") ||
283+
(utf_char_next == "v" && utf_char_next_next == "e") ||
284+
(utf_char_next == "l" && utf_char_next_next == "l"))
285+
) {
286+
split_condition = true;
287+
}
288+
if (split_condition) {
289+
// current token + next token can be defined
290+
if (token.size()) {
291+
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
292+
}
293+
token = utf_char + utf_char_next + utf_char_next_next;
294+
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
295+
token = "";
296+
i += 2;
297+
continue;
288298
}
289-
token = utf_char + utf_char_next + utf_char_next_next;
290-
bpe_words.emplace_back(token); // the contraction
291-
token = "";
292-
i += 2;
293-
continue;
294299
}
295-
}
296300

297-
if (!split_condition && !collecting) {
298-
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
299-
collecting_letter = true;
300-
collecting = true;
301-
}
302-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
303-
collecting_numeric = true;
304-
collecting = true;
305-
}
306-
else if (
307-
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
308-
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
309-
) {
310-
collecting_special = true;
311-
collecting = true;
312-
}
313-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
314-
collecting_whitespace_lookahead = true;
315-
collecting = true;
316-
}
317-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
318-
split_condition = true;
319-
}
320-
}
321-
else if (!split_condition && collecting) {
322-
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
323-
split_condition = true;
324-
}
325-
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
326-
split_condition = true;
327-
}
328-
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
329-
split_condition = true;
301+
if (!split_condition && !collecting) {
302+
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
303+
collecting_letter = true;
304+
collecting = true;
305+
}
306+
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
307+
collecting_numeric = true;
308+
collecting = true;
309+
}
310+
else if (
311+
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
312+
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
313+
) {
314+
collecting_special = true;
315+
collecting = true;
316+
}
317+
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
318+
collecting_whitespace_lookahead = true;
319+
collecting = true;
320+
}
321+
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
322+
split_condition = true;
323+
}
330324
}
331-
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
332-
split_condition = true;
325+
else if (!split_condition && collecting) {
326+
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
327+
split_condition = true;
328+
}
329+
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
330+
split_condition = true;
331+
}
332+
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
333+
split_condition = true;
334+
}
335+
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
336+
split_condition = true;
337+
}
333338
}
334-
}
335339

336-
if (utf_char_next == "") {
337-
split_condition = true; // final
338-
token += utf_char;
339-
}
340+
if (utf_char_next == "") {
341+
split_condition = true; // final
342+
token += utf_char;
343+
}
340344

341-
if (split_condition) {
342-
if (token.size()) {
343-
bpe_words.emplace_back(token);
345+
if (split_condition) {
346+
if (token.size()) {
347+
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
348+
}
349+
token = utf_char;
350+
collecting = false;
351+
collecting_letter = false;
352+
collecting_numeric = false;
353+
collecting_special = false;
354+
collecting_whitespace_lookahead = false;
355+
}
356+
else {
357+
token += utf_char;
344358
}
345-
token = utf_char;
346-
collecting = false;
347-
collecting_letter = false;
348-
collecting_numeric = false;
349-
collecting_special = false;
350-
collecting_whitespace_lookahead = false;
351-
}
352-
else {
353-
token += utf_char;
354359
}
355360
}
356361

357-
return bpe_words;
362+
return bpe_offsets;
358363
}
359364

360365
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
@@ -386,16 +391,22 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
386391
return bpe_offsets;
387392
}
388393

389-
static bool unicode_regex_matched(const std::wstring & text, const std::vector<std::wstring> & regex_exprs) {
394+
static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
395+
return unicode_regex_equivalent_wregex.find(regex) != unicode_regex_equivalent_wregex.end();
396+
}
390397

391-
for(auto & regex_expr: regex_exprs) {
392-
std::wregex expr(regex_expr);
393-
if(std::regex_match(text, expr)) {
394-
return true;
395-
}
398+
static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) {
399+
return unicode_regex_with_custom_preprocessor.find(regex) != unicode_regex_with_custom_preprocessor.end();
400+
}
401+
402+
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::wstring & wtext, const std::vector<size_t> & offsets) {
403+
std::vector<size_t> bpe_offsets;
404+
405+
if(regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
406+
bpe_offsets = unicode_gpt2_regex_preprocess(wtext, offsets);
396407
}
397408

398-
return false;
409+
return bpe_offsets;
399410
}
400411

401412
//
@@ -479,33 +490,29 @@ char32_t unicode_tolower(char32_t cp) {
479490
auto it = unicode_map_lowercase.find(cp);
480491
return it == unicode_map_lowercase.end() ? cp : it->second;
481492
}
482-
483-
bool unicode_wregex_exists(const std::string & regex) {
484-
return unicode_regex_to_wregex.find(regex) != unicode_regex_to_wregex.end();
485-
}
486-
487-
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::wstring> & regex_exprs) {
493+
494+
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
488495
std::wstring wtext = unicode_wstring_from_utf8(text);
489496

490497
std::vector<size_t> bpe_offsets = {wtext.size()};
491498

492499
for(auto & regex_expr : regex_exprs) {
493-
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, regex_expr);
500+
501+
if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
502+
const std::wstring& wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
503+
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
504+
} else if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
505+
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
506+
} else {
507+
throw std::runtime_error("Unicode regex is not found");
508+
}
494509
}
495510

496511
std::vector<std::string> bpe_words;
497512
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
498513
size_t start = 0;
499-
for(size_t & offset : bpe_offsets){
500-
const auto temp_word = std::wstring(wtext, start, offset);
501-
502-
if(unicode_regex_matched(temp_word, regex_exprs)) {
503-
bpe_words.emplace_back(unicode_wstring_to_utf8(temp_word));
504-
} else {
505-
auto custom_bpe_words = unicode_custom_preprocess(unicode_wstring_to_utf8(temp_word));
506-
bpe_words.insert(bpe_words.end(), custom_bpe_words.begin(), custom_bpe_words.end());
507-
}
508-
514+
for(size_t & offset : bpe_offsets) {
515+
bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
509516
start += offset;
510517
}
511518

unicode.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,4 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8);
2828

2929
char32_t unicode_tolower(char32_t cp);
3030

31-
bool unicode_wregex_exists(const std::string & regex);
32-
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::wstring> & regex_exprs);
31+
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);

0 commit comments

Comments
 (0)