Commit cf45252

tests : multi-thread the tokenizer tests (#5474)
* tests : multi-thread the tokenizer tests
* unicode : fix data race for unidentified codepoints
* unicode : minor style fixes
1 parent 03bf161 commit cf45252

4 files changed, 121 insertions(+), 99 deletions(-)


llama.cpp

Lines changed: 13 additions & 11 deletions
@@ -7782,7 +7782,7 @@ struct llm_bigram_spm {
 };
 
 struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
+    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         // split string into utf8 chars
@@ -7857,6 +7857,7 @@ struct llm_tokenizer_spm {
 
         if (p == rev_merge.end()) {
             // output any symbols that did not form tokens as bytes.
+            output.reserve(output.size() + symbol.n);
             for (int j = 0; j < (int)symbol.n; ++j) {
                 llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
                 output.push_back(token_id);
@@ -8419,17 +8420,18 @@ struct fragment_buffer_variant {
         token(_token),
         raw_text(_dummy),
         offset(0),
-        length(0){}
+        length(0) {}
+
     fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
         :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id)-1),
+        token((llama_vocab::id) - 1),
         raw_text(_raw_text),
         offset(_offset),
         length(_length){
-            GGML_ASSERT( _offset >= 0 );
-            GGML_ASSERT( _length >= 1 );
-            GGML_ASSERT( offset + length <= raw_text.length() );
+            GGML_ASSERT(_offset >= 0);
+            GGML_ASSERT(_length >= 1);
+            GGML_ASSERT(offset + length <= raw_text.length());
         }
 
     const FRAGMENT_BUFFER_VARIANT_TYPE type;
@@ -8553,14 +8555,14 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
     }
 
     std::forward_list<fragment_buffer_variant> fragment_buffer;
-    fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );
+    fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
 
-    if (special) tokenizer_st_partition( vocab, fragment_buffer );
+    if (special) tokenizer_st_partition(vocab, fragment_buffer);
 
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
-                for (const auto & fragment: fragment_buffer) {
+                for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         // without adding this leading whitespace, we do not get the same results as the original tokenizer
 
@@ -8588,7 +8590,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 } break;
             case LLAMA_VOCAB_TYPE_BPE:
                 {
-                    for (const auto & fragment: fragment_buffer) {
+                    for (const auto & fragment : fragment_buffer) {
                         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                             auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
@@ -8604,7 +8606,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 } break;
            case LLAMA_VOCAB_TYPE_WPM:
                {
-                    for (const auto & fragment: fragment_buffer) {
+                    for (const auto & fragment : fragment_buffer) {
                         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                             auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
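The one functional tweak in this file is the new output.reserve(output.size() + symbol.n) call before the byte-fallback loop: reserving capacity once avoids a reallocation per push_back. A minimal standalone sketch of that pattern (not the llama.cpp code itself; the helper name is illustrative):

#include <cstdio>
#include <vector>

// Illustrative helper, not from llama.cpp: append n byte values to an id vector.
static void append_bytes(std::vector<int> & output, const char * text, int n) {
    output.reserve(output.size() + n); // grow once instead of reallocating per push_back
    for (int j = 0; j < n; ++j) {
        output.push_back((unsigned char) text[j]);
    }
}

int main() {
    std::vector<int> out;
    append_bytes(out, "abc", 3);
    printf("appended %zu ids\n", out.size()); // prints: appended 3 ids
    return 0;
}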

tests/test-tokenizer-1-bpe.cpp

Lines changed: 37 additions & 36 deletions
@@ -4,13 +4,13 @@
 #include "console.h"
 
 #include <cassert>
+#include <codecvt>
 #include <cstdio>
 #include <cstring>
+#include <locale>
 #include <string>
-#include <codecvt>
-#include <map>
+#include <thread>
 #include <vector>
-#include <locale>
 
 int main(int argc, char **argv) {
     if (argc < 2) {
@@ -74,45 +74,46 @@ int main(int argc, char **argv) {
             }
         }
         catch (const std::invalid_argument &) {
-            fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str());
+            //fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str());
         }
     }
 
-    for (uint32_t cp = 0x0000; cp < 0xffff; ++cp) {
-        // NOTE: these exceptions seem to be necessary, because the GPT2 tokenizer doesn't want to interfere with some ASCII control characters
-        if ((cp < 0x03 || cp > 0x05) && cp != 0x0b && cp != 0x11 && (cp < 0x13 || cp > 0x17) && cp != 0x19 && (cp < 0x1c || cp > 0x1e) && (cp < 0xd800 || cp > 0xdfff)) {
-            std::string str = " " + codepoint_to_utf8(cp);
-            std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-            std::string check = llama_detokenize_bpe(ctx, tokens);
-            if (str != check) {
-                fprintf(stderr, "%s : error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
-                    __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-                return 3;
-            }
-        }
-    }
-    // Restrict to assigned unicode planes
-    // for (uint32_t cp = 0x10000; cp < 0x0010ffff; ++cp) {
-    for (uint32_t cp = 0x10000; cp < 0x00040000; ++cp) {
-        std::string str = codepoint_to_utf8(cp);
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-        std::string check = llama_detokenize_bpe(ctx, tokens);
-        if (str != check) {
-            fprintf(stderr, "%s : error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
-                __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-            return 4;
-        }
-    }
-    for (uint32_t cp = 0x000e0000; cp < 0x0010ffff; ++cp) {
-        std::string str = codepoint_to_utf8(cp);
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-        std::string check = llama_detokenize_bpe(ctx, tokens);
-        if (str != check) {
-            fprintf(stderr, "%s : error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
-                __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-            return 4;
-        }
-    }
+    // unicode
+    {
+        const int nthread = std::thread::hardware_concurrency();
+
+        std::vector<std::thread> threads(nthread);
+
+        for (int i = 0; i < nthread; ++i) {
+            threads[i] = std::thread([i, nthread, ctx]() {
+                for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
+                    if (!( // NOLINT
+                        (cp < 0x03 || cp > 0x05) && cp != 0x0b && cp != 0x11 &&
+                        (cp < 0x13 || cp > 0x17) && cp != 0x19 &&
+                        (cp < 0x1c || cp > 0x1e) &&
+                        (cp < 0xd800 || cp > 0xdfff) &&
+                        (cp < 0x00040000 || cp >= 0x000e0000)
+                        )) {
+                        continue;
+                    }
+
+                    std::string str = codepoint_to_utf8(cp);
+                    std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
+                    std::string check = llama_detokenize_bpe(ctx, tokens);
+                    if (cp != 9601 && str != check) {
+                        fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+                            cp, check.c_str(), check.length(), str.c_str(), str.length());
+                        std::exit(3);
+                    }
+                }
+            });
+        }
+
+        for (auto & t : threads) {
+            t.join();
+        }
+    }
+
     llama_free_model(model);
     llama_free(ctx);
 
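Both tokenizer tests (the BPE one above and the SPM one below) now split the full codepoint range 0x0 .. 0x10ffff across hardware threads in the same interleaved way: thread i checks codepoints i, i + nthread, i + 2*nthread, and so on, so no synchronization between workers is needed. A minimal standalone sketch of that partitioning, assuming a hypothetical check_codepoint() in place of the real tokenize/detokenize round-trip:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// Hypothetical stand-in for the real check (tokenize, detokenize, compare).
static void check_codepoint(uint32_t cp) {
    if (cp >= 0xd800 && cp <= 0xdfff) {
        return; // skip UTF-16 surrogate halves, as the tests do
    }
    // ... round-trip check would go here ...
}

int main() {
    // hardware_concurrency() may return 0, so fall back to a single thread
    const int nthread = std::max(1u, std::thread::hardware_concurrency());

    std::vector<std::thread> threads(nthread);

    // interleaved partitioning: thread i handles cp = i, i + nthread, i + 2*nthread, ...
    for (int i = 0; i < nthread; ++i) {
        threads[i] = std::thread([i, nthread]() {
            for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
                check_codepoint(cp);
            }
        });
    }

    for (auto & t : threads) {
        t.join();
    }

    printf("checked all codepoints on %d threads\n", nthread);
    return 0;
}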

tests/test-tokenizer-1-llama.cpp

Lines changed: 29 additions & 22 deletions
@@ -4,13 +4,13 @@
 #include "console.h"
 
 #include <cassert>
+#include <codecvt>
 #include <cstdio>
 #include <cstring>
+#include <locale>
 #include <string>
-#include <codecvt>
-#include <map>
+#include <thread>
 #include <vector>
-#include <locale>
 
 int main(int argc, char **argv) {
     if (argc < 2) {
@@ -72,26 +72,33 @@ int main(int argc, char **argv) {
         }
     }
 
-    for (uint32_t cp = 0x0000; cp < 0xffff; ++cp) {
-        if (cp < 0xd800 || cp > 0xdfff) {
-            std::string str = codepoint_to_utf8(cp);
-            std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-            std::string check = llama_detokenize_spm(ctx, tokens);
-            if (cp != 9601 && str != check) {
-                fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
-                    __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-                return 3;
-            }
-        }
-    }
-    for (uint32_t cp = 0x10000; cp < 0x0010ffff; ++cp) {
-        std::string str = codepoint_to_utf8(cp);
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-        std::string check = llama_detokenize_spm(ctx, tokens);
-        if (str != check) {
-            fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
-                __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-            return 4;
-        }
-    }
+    // unicode
+    {
+        const int nthread = std::thread::hardware_concurrency();
+
+        std::vector<std::thread> threads(nthread);
+
+        for (int i = 0; i < nthread; ++i) {
+            threads[i] = std::thread([i, nthread, ctx]() {
+                for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
+                    if (cp >= 0xd800 && cp <= 0xdfff) {
+                        continue;
+                    }
+
+                    std::string str = codepoint_to_utf8(cp);
+                    std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
+                    std::string check = llama_detokenize_spm(ctx, tokens);
+                    if (cp != 9601 && str != check) {
+                        fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+                            cp, check.c_str(), check.length(), str.c_str(), str.length());
+                        std::exit(3);
+                    }
+                }
+            });
+        }
+
+        for (auto & t : threads) {
+            t.join();
+        }
+    }
 
unicode.h

Lines changed: 42 additions & 30 deletions
@@ -264,26 +264,29 @@ static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) {
         offset += 1;
         return result;
     }
-    else if (!(utf8[offset + 0] & 0x40)) {
+    if (!(utf8[offset + 0] & 0x40)) {
         throw std::invalid_argument("invalid character");
     }
-    else if (!(utf8[offset + 0] & 0x20)) {
-        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80))
+    if (!(utf8[offset + 0] & 0x20)) {
+        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
+        }
         auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
         offset += 2;
         return result;
     }
-    else if (!(utf8[offset + 0] & 0x10)) {
-        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80))
+    if (!(utf8[offset + 0] & 0x10)) {
+        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
+        }
         auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
         offset += 3;
         return result;
     }
-    else if (!(utf8[offset + 0] & 0x08)) {
-        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80))
+    if (!(utf8[offset + 0] & 0x08)) {
+        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
+        }
         auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
         offset += 4;
         return result;
@@ -331,21 +334,22 @@ static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t
         offset += 1;
         return result;
     }
-    else {
-        if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00))
-            throw std::invalid_argument("invalid character");
-        auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
-        offset += 2;
-        return result;
+
+    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
+        throw std::invalid_argument("invalid character");
     }
-    throw std::invalid_argument("invalid string");
+
+    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
+    offset += 2;
+    return result;
 }
 
 static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) {
     std::vector<uint32_t> result;
     size_t offset = 0;
-    while (offset < utf16.size())
+    while (offset < utf16.size()) {
         result.push_back(codepoint_from_utf16(utf16, offset));
+    }
     return result;
 }
 
@@ -361,44 +365,52 @@ static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> &
 static std::unordered_map<uint32_t, int> codepoint_type_map() {
     std::unordered_map<uint32_t, int> codepoint_types;
     for (auto p : digit_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_DIGIT;
+        }
     }
-    for(auto p : letter_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : letter_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_LETTER;
+        }
     }
-    for(auto p : whitespace_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : whitespace_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
            codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE;
+        }
     }
-    for(auto p : accent_mark_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : accent_mark_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
            codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
+        }
     }
-    for(auto p : punctuation_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : punctuation_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
            codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION;
+        }
     }
-    for (auto p : symbol_ranges) {
-        for (auto i = p.first; i <= p.second; ++i)
+    for (auto p : symbol_ranges) {
+        for (auto i = p.first; i <= p.second; ++i) {
            codepoint_types[i] = CODEPOINT_TYPE_SYMBOL;
+        }
     }
-    for(auto p : control_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : control_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
            codepoint_types[i] = CODEPOINT_TYPE_CONTROL;
+        }
     }
     return codepoint_types;
 }
 
 static int codepoint_type(uint32_t cp) {
     static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map();
-    return codepoint_types[cp];
+    return codepoint_types.find(cp) == codepoint_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : codepoint_types.at(cp);
 }
 
 static int codepoint_type(const std::string & utf8) {
-    if (utf8.length() == 0)
+    if (utf8.length() == 0) {
         return CODEPOINT_TYPE_UNIDENTIFIED;
+    }
     size_t offset = 0;
     return codepoint_type(codepoint_from_utf8(utf8, offset));
 }
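The codepoint_type() change above is the actual data-race fix named in the commit message. The old body, return codepoint_types[cp];, uses std::unordered_map::operator[], which inserts a default-constructed entry whenever the key is missing; once the multi-threaded tests query unidentified codepoints from several threads at once, those concurrent insertions mutate the shared static map without synchronization. Looking the key up with find() keeps the call read-only. A minimal sketch of the two access patterns, using an illustrative table rather than the real one in unicode.h:

#include <cstdint>
#include <cstdio>
#include <unordered_map>

// Illustrative values only; the real constants live in unicode.h.
enum { CODEPOINT_TYPE_UNIDENTIFIED = 0, CODEPOINT_TYPE_LETTER = 4 };

// Illustrative shared table standing in for the static map built by codepoint_type_map().
static std::unordered_map<uint32_t, int> & table() {
    static std::unordered_map<uint32_t, int> t = { { 'a', CODEPOINT_TYPE_LETTER } };
    return t;
}

// Racy under concurrent callers: operator[] inserts {cp, 0} for unknown keys,
// so two threads probing unidentified codepoints can both mutate the map.
static int lookup_racy(uint32_t cp) {
    return table()[cp];
}

// Read-only lookup: find() never modifies the map, so concurrent readers are safe
// once the table has been built.
static int lookup_safe(uint32_t cp) {
    const auto it = table().find(cp);
    return it == table().end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
}

int main() {
    printf("%d %d\n", lookup_safe('a'), lookup_safe(0x10ffff)); // prints: 4 0
    (void) lookup_racy; // kept only to show the problematic pattern
    return 0;
}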
