@@ -228,133 +228,138 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
228
228
return bpe_encoded_words;
229
229
}
230
230
231
- static std::vector<std::string> unicode_custom_preprocess (const std::string & text) {
232
- std::vector<std::string> bpe_words;
231
+ static std::vector<size_t > unicode_gpt2_regex_preprocess (const std::wstring & wtext, const std::vector<size_t > & offsets) {
232
+ std::vector<size_t > bpe_offsets; // store the offset of each word
233
+ bpe_offsets.reserve (offsets.size ()); // Reserve memory for the approximate size
234
+ size_t start = 0 ;
233
235
234
- std::string token = " " ;
235
- // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
236
- bool collecting_numeric = false ;
237
- bool collecting_letter = false ;
238
- bool collecting_special = false ;
239
- bool collecting_whitespace_lookahead = false ;
240
- bool collecting = false ;
241
-
242
- std::vector<std::string> text_utf;
243
- text_utf.reserve (text.size ());
244
- bpe_words.reserve (text.size ());
245
-
246
- const auto cpts = unicode_cpts_from_utf8 (text);
247
- for (size_t i = 0 ; i < cpts.size (); ++i)
248
- text_utf.emplace_back (unicode_cpt_to_utf8 (cpts[i]));
249
-
250
- for (int i = 0 ; i < (int )text_utf.size (); i++) {
251
- const std::string & utf_char = text_utf[i];
252
- bool split_condition = false ;
253
- int bytes_remain = text_utf.size () - i;
254
- // forward backward lookups
255
- const std::string & utf_char_next = (i + 1 < (int )text_utf.size ()) ? text_utf[i + 1 ] : " " ;
256
- const std::string & utf_char_next_next = (i + 2 < (int )text_utf.size ()) ? text_utf[i + 2 ] : " " ;
257
-
258
- // handling contractions
259
- if (!split_condition && bytes_remain >= 2 ) {
260
- // 's|'t|'m|'d
261
- if (utf_char == " \' " && (utf_char_next == " s" || utf_char_next == " t" || utf_char_next == " m" || utf_char_next == " d" )) {
262
- split_condition = true ;
263
- }
264
- if (split_condition) {
265
- if (token.size ()) {
266
- bpe_words.emplace_back (token); // push previous content as token
236
+ for (auto offset : offsets) {
237
+ const std::string text = unicode_wstring_to_utf8 (std::wstring (wtext, start, offset));
238
+
239
+ std::string token = " " ;
240
+ // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
241
+ bool collecting_numeric = false ;
242
+ bool collecting_letter = false ;
243
+ bool collecting_special = false ;
244
+ bool collecting_whitespace_lookahead = false ;
245
+ bool collecting = false ;
246
+
247
+ std::vector<std::string> text_utf;
248
+ text_utf.reserve (text.size ());
249
+
250
+ const auto cpts = unicode_cpts_from_utf8 (text);
251
+ for (size_t i = 0 ; i < cpts.size (); ++i)
252
+ text_utf.emplace_back (unicode_cpt_to_utf8 (cpts[i]));
253
+
254
+ for (int i = 0 ; i < (int )text_utf.size (); i++) {
255
+ const std::string & utf_char = text_utf[i];
256
+ bool split_condition = false ;
257
+ int bytes_remain = text_utf.size () - i;
258
+ // forward backward lookups
259
+ const std::string & utf_char_next = (i + 1 < (int )text_utf.size ()) ? text_utf[i + 1 ] : " " ;
260
+ const std::string & utf_char_next_next = (i + 2 < (int )text_utf.size ()) ? text_utf[i + 2 ] : " " ;
261
+
262
+ // handling contractions
263
+ if (!split_condition && bytes_remain >= 2 ) {
264
+ // 's|'t|'m|'d
265
+ if (utf_char == " \' " && (utf_char_next == " s" || utf_char_next == " t" || utf_char_next == " m" || utf_char_next == " d" )) {
266
+ split_condition = true ;
267
+ }
268
+ if (split_condition) {
269
+ if (token.size ()) {
270
+ bpe_offsets.emplace_back (unicode_wstring_from_utf8 (token).size ());
271
+ }
272
+ token = utf_char + utf_char_next;
273
+ bpe_offsets.emplace_back (unicode_wstring_from_utf8 (token).size ());
274
+ token = " " ;
275
+ i++;
276
+ continue ;
267
277
}
268
- token = utf_char + utf_char_next;
269
- bpe_words.emplace_back (token);
270
- token = " " ;
271
- i++;
272
- continue ;
273
- }
274
- }
275
- if (!split_condition && bytes_remain >= 3 ) {
276
- // 're|'ve|'ll
277
- if (utf_char == " \' " && (
278
- (utf_char_next == " r" && utf_char_next_next == " e" ) ||
279
- (utf_char_next == " v" && utf_char_next_next == " e" ) ||
280
- (utf_char_next == " l" && utf_char_next_next == " l" ))
281
- ) {
282
- split_condition = true ;
283
278
}
284
- if (split_condition) {
285
- // current token + next token can be defined
286
- if (token.size ()) {
287
- bpe_words.emplace_back (token); // push previous content as token
279
+ if (!split_condition && bytes_remain >= 3 ) {
280
+ // 're|'ve|'ll
281
+ if (utf_char == " \' " && (
282
+ (utf_char_next == " r" && utf_char_next_next == " e" ) ||
283
+ (utf_char_next == " v" && utf_char_next_next == " e" ) ||
284
+ (utf_char_next == " l" && utf_char_next_next == " l" ))
285
+ ) {
286
+ split_condition = true ;
287
+ }
288
+ if (split_condition) {
289
+ // current token + next token can be defined
290
+ if (token.size ()) {
291
+ bpe_offsets.emplace_back (unicode_wstring_from_utf8 (token).size ());
292
+ }
293
+ token = utf_char + utf_char_next + utf_char_next_next;
294
+ bpe_offsets.emplace_back (unicode_wstring_from_utf8 (token).size ());
295
+ token = " " ;
296
+ i += 2 ;
297
+ continue ;
288
298
}
289
- token = utf_char + utf_char_next + utf_char_next_next;
290
- bpe_words.emplace_back (token); // the contraction
291
- token = " " ;
292
- i += 2 ;
293
- continue ;
294
299
}
295
- }
296
300
297
- if (!split_condition && !collecting) {
298
- if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_LETTER || (!token.size () && utf_char == " " && unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_LETTER)) {
299
- collecting_letter = true ;
300
- collecting = true ;
301
- }
302
- else if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size () && utf_char == " " && unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
303
- collecting_numeric = true ;
304
- collecting = true ;
305
- }
306
- else if (
307
- ((unicode_cpt_type (utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type (utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type (utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
308
- (!token.size () && utf_char == " " && unicode_cpt_type (utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type (utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type (utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
309
- ) {
310
- collecting_special = true ;
311
- collecting = true ;
312
- }
313
- else if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
314
- collecting_whitespace_lookahead = true ;
315
- collecting = true ;
316
- }
317
- else if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_WHITESPACE) {
318
- split_condition = true ;
319
- }
320
- }
321
- else if (!split_condition && collecting) {
322
- if (collecting_letter && unicode_cpt_type (utf_char) != CODEPOINT_TYPE_LETTER) {
323
- split_condition = true ;
324
- }
325
- else if (collecting_numeric && unicode_cpt_type (utf_char) != CODEPOINT_TYPE_DIGIT) {
326
- split_condition = true ;
327
- }
328
- else if (collecting_special && (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type (utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type (utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
329
- split_condition = true ;
301
+ if (!split_condition && !collecting) {
302
+ if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_LETTER || (!token.size () && utf_char == " " && unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_LETTER)) {
303
+ collecting_letter = true ;
304
+ collecting = true ;
305
+ }
306
+ else if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size () && utf_char == " " && unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
307
+ collecting_numeric = true ;
308
+ collecting = true ;
309
+ }
310
+ else if (
311
+ ((unicode_cpt_type (utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type (utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type (utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
312
+ (!token.size () && utf_char == " " && unicode_cpt_type (utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type (utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type (utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
313
+ ) {
314
+ collecting_special = true ;
315
+ collecting = true ;
316
+ }
317
+ else if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
318
+ collecting_whitespace_lookahead = true ;
319
+ collecting = true ;
320
+ }
321
+ else if (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_WHITESPACE) {
322
+ split_condition = true ;
323
+ }
330
324
}
331
- else if (collecting_whitespace_lookahead && (unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
332
- split_condition = true ;
325
+ else if (!split_condition && collecting) {
326
+ if (collecting_letter && unicode_cpt_type (utf_char) != CODEPOINT_TYPE_LETTER) {
327
+ split_condition = true ;
328
+ }
329
+ else if (collecting_numeric && unicode_cpt_type (utf_char) != CODEPOINT_TYPE_DIGIT) {
330
+ split_condition = true ;
331
+ }
332
+ else if (collecting_special && (unicode_cpt_type (utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type (utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type (utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
333
+ split_condition = true ;
334
+ }
335
+ else if (collecting_whitespace_lookahead && (unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type (utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
336
+ split_condition = true ;
337
+ }
333
338
}
334
- }
335
339
336
- if (utf_char_next == " " ) {
337
- split_condition = true ; // final
338
- token += utf_char;
339
- }
340
+ if (utf_char_next == " " ) {
341
+ split_condition = true ; // final
342
+ token += utf_char;
343
+ }
340
344
341
- if (split_condition) {
342
- if (token.size ()) {
343
- bpe_words.emplace_back (token);
345
+ if (split_condition) {
346
+ if (token.size ()) {
347
+ bpe_offsets.emplace_back (unicode_wstring_from_utf8 (token).size ());
348
+ }
349
+ token = utf_char;
350
+ collecting = false ;
351
+ collecting_letter = false ;
352
+ collecting_numeric = false ;
353
+ collecting_special = false ;
354
+ collecting_whitespace_lookahead = false ;
355
+ }
356
+ else {
357
+ token += utf_char;
344
358
}
345
- token = utf_char;
346
- collecting = false ;
347
- collecting_letter = false ;
348
- collecting_numeric = false ;
349
- collecting_special = false ;
350
- collecting_whitespace_lookahead = false ;
351
- }
352
- else {
353
- token += utf_char;
354
359
}
355
360
}
356
361
357
- return bpe_words ;
362
+ return bpe_offsets ;
358
363
}
359
364
360
365
static std::vector<size_t > unicode_regex_preprocess (const std::wstring & text, const std::vector<size_t > & offsets, const std::wstring & regex_expr) {
@@ -386,16 +391,22 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
386
391
return bpe_offsets;
387
392
}
388
393
389
- static bool unicode_regex_matched (const std::wstring & text, const std::vector<std::wstring> & regex_exprs) {
394
+ static bool unicode_regex_equivalent_wregex_exists (const std::string & regex) {
395
+ return unicode_regex_equivalent_wregex.find (regex) != unicode_regex_equivalent_wregex.end ();
396
+ }
390
397
391
- for (auto & regex_expr: regex_exprs) {
392
- std::wregex expr (regex_expr);
393
- if (std::regex_match (text, expr)) {
394
- return true ;
395
- }
398
+ static bool unicode_regex_with_custom_preprocessor_exists (const std::string & regex) {
399
+ return unicode_regex_with_custom_preprocessor.find (regex) != unicode_regex_with_custom_preprocessor.end ();
400
+ }
401
+
402
+ static std::vector<size_t > unicode_regex_custom_preprocess (const std::string & regex, const std::wstring & wtext, const std::vector<size_t > & offsets) {
403
+ std::vector<size_t > bpe_offsets;
404
+
405
+ if (regex == " 's|'t|'re|'ve|'m|'ll|'d| ?\\ p{L}+| ?\\ p{N}+| ?[^\\ s\\ p{L}\\ p{N}]+|\\ s+(?!\\ S)" ) {
406
+ bpe_offsets = unicode_gpt2_regex_preprocess (wtext, offsets);
396
407
}
397
408
398
- return false ;
409
+ return bpe_offsets ;
399
410
}
400
411
401
412
//
@@ -479,33 +490,29 @@ char32_t unicode_tolower(char32_t cp) {
479
490
auto it = unicode_map_lowercase.find (cp);
480
491
return it == unicode_map_lowercase.end () ? cp : it->second ;
481
492
}
482
-
483
- bool unicode_wregex_exists (const std::string & regex) {
484
- return unicode_regex_to_wregex.find (regex) != unicode_regex_to_wregex.end ();
485
- }
486
-
487
- std::vector<std::string> unicode_regex_split (const std::string & text, const std::vector<std::wstring> & regex_exprs) {
493
+
494
+ std::vector<std::string> unicode_regex_split (const std::string & text, const std::vector<std::string> & regex_exprs) {
488
495
std::wstring wtext = unicode_wstring_from_utf8 (text);
489
496
490
497
std::vector<size_t > bpe_offsets = {wtext.size ()};
491
498
492
499
for (auto & regex_expr : regex_exprs) {
493
- bpe_offsets = unicode_regex_preprocess (wtext, bpe_offsets, regex_expr);
500
+
501
+ if (unicode_regex_equivalent_wregex_exists (regex_expr)) {
502
+ const std::wstring& wregex_expr = unicode_regex_equivalent_wregex.at (regex_expr);
503
+ bpe_offsets = unicode_regex_preprocess (wtext, bpe_offsets, wregex_expr);
504
+ } else if (unicode_regex_with_custom_preprocessor_exists (regex_expr)) {
505
+ bpe_offsets = unicode_regex_custom_preprocess (regex_expr, wtext, bpe_offsets);
506
+ } else {
507
+ throw std::runtime_error (" Unicode regex is not found" );
508
+ }
494
509
}
495
510
496
511
std::vector<std::string> bpe_words;
497
512
bpe_words.reserve (bpe_offsets.size ()); // Reserve memory for the approximate size
498
513
size_t start = 0 ;
499
- for (size_t & offset : bpe_offsets){
500
- const auto temp_word = std::wstring (wtext, start, offset);
501
-
502
- if (unicode_regex_matched (temp_word, regex_exprs)) {
503
- bpe_words.emplace_back (unicode_wstring_to_utf8 (temp_word));
504
- } else {
505
- auto custom_bpe_words = unicode_custom_preprocess (unicode_wstring_to_utf8 (temp_word));
506
- bpe_words.insert (bpe_words.end (), custom_bpe_words.begin (), custom_bpe_words.end ());
507
- }
508
-
514
+ for (size_t & offset : bpe_offsets) {
515
+ bpe_words.emplace_back (unicode_wstring_to_utf8 (std::wstring (wtext, start, offset)));
509
516
start += offset;
510
517
}
511
518
0 commit comments