
Commit cba637e

Tokenizer test
1 parent b3ba207 commit cba637e

5 files changed: +74 -158 lines changed


test/test_base64.cpp

Lines changed: 1 addition & 1 deletion
@@ -6,8 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <pytorch/tokenizers/base64.h>
 #include "gtest/gtest.h"
+#include <pytorch/tokenizers/base64.h>
 
 namespace tokenizers {
 

test/test_llama2c_tokenizer.cpp

Lines changed: 5 additions & 20 deletions
@@ -6,34 +6,19 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#ifdef TOKENIZERS_FB_BUCK
-#include <TestResourceUtils/TestResourceUtils.h>
-#endif
 #include <gtest/gtest.h>
 #include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 using namespace ::testing;
 
 namespace tokenizers {
 
-namespace {
-// Test case based on llama2.c tokenizer
-static inline std::string _get_resource_path(const std::string& name) {
-#ifdef TOKENIZERS_FB_BUCK
-  return facebook::xplat::testing::getPathForTestResource(
-      "test/resources/" + name);
-#else
-  return std::getenv("RESOURCES_PATH") + std::string("/") + name;
-#endif
-}
-
-} // namespace
-
 class Llama2cTokenizerTest : public Test {
- public:
+public:
   void SetUp() override {
     tokenizer_ = std::make_unique<Llama2cTokenizer>();
-    modelPath_ = _get_resource_path("test_llama2c_tokenizer.bin");
+    modelPath_ = std::getenv("RESOURCES_PATH") +
+                 std::string("/test_llama2c_tokenizer.bin");
   }
 
   std::unique_ptr<Tokenizer> tokenizer_;
@@ -51,15 +36,15 @@ TEST_F(Llama2cTokenizerTest, DecodeWithoutLoadFails) {
 }
 
 TEST_F(Llama2cTokenizerTest, DecodeOutOfRangeFails) {
-  Error res = tokenizer_->load(modelPath_.c_str());
+  Error res = tokenizer_->load(modelPath_);
   EXPECT_EQ(res, Error::Ok);
   auto result = tokenizer_->decode(0, 64000);
   // The vocab size is 32000, and token 64000 is out of vocab range.
   EXPECT_EQ(result.error(), Error::OutOfRange);
 }
 
 TEST_F(Llama2cTokenizerTest, TokenizerMetadataIsExpected) {
-  Error res = tokenizer_->load(modelPath_.c_str());
+  Error res = tokenizer_->load(modelPath_);
   EXPECT_EQ(res, Error::Ok);
   // test_bpe_tokenizer.bin has vocab_size 0, bos_id 0, eos_id 0 recorded.
   EXPECT_EQ(tokenizer_->vocab_size(), 0);

test/test_pre_tokenizer.cpp

Lines changed: 25 additions & 62 deletions
@@ -19,12 +19,11 @@ using namespace tokenizers;
 
 // Helpers /////////////////////////////////////////////////////////////////////
 
-static void assert_split_match(
-    const PreTokenizer& ptok,
-    const std::string& prompt,
-    const std::vector<std::string>& expected) {
+static void assert_split_match(const PreTokenizer &ptok,
+                               const std::string &prompt,
+                               const std::vector<std::string> &expected) {
   re2::StringPiece prompt_view(prompt);
-  const auto& got = ptok.pre_tokenize(prompt_view);
+  const auto &got = ptok.pre_tokenize(prompt_view);
   EXPECT_EQ(expected.size(), got.size());
   for (auto i = 0; i < got.size(); ++i) {
     EXPECT_EQ(expected[i], got[i]);
@@ -35,16 +34,14 @@ static void assert_split_match(
 class RegexPreTokenizerTest : public ::testing::Test {};
 
 // Test the basic construction
-TEST_F(RegexPreTokenizerTest, Construct) {
-  RegexPreTokenizer ptok("[0-9]+");
-}
+TEST_F(RegexPreTokenizerTest, Construct) { RegexPreTokenizer ptok("[0-9]+"); }
 
 // Test basic splitting using the expression for Tiktoken
 TEST_F(RegexPreTokenizerTest, TiktokenExpr) {
   RegexPreTokenizer ptok(
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)");
-  assert_split_match(
-      ptok, "How are you doing?", {"How", " are", " you", " doing", "?"});
+  assert_split_match(ptok, "How are you doing?",
+                     {"How", " are", " you", " doing", "?"});
 }
 
 // DigitsPreTokenizer //////////////////////////////////////////////////////////
@@ -54,18 +51,15 @@ class DigitsPreTokenizerTest : public ::testing::Test {};
 TEST_F(DigitsPreTokenizerTest, IndividualDigits) {
   DigitsPreTokenizer ptok(true);
   assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
+      ptok, "The number 1 then 234 then 5.",
       {"The number ", "1", " then ", "2", "3", "4", " then ", "5", "."});
 }
 
 // Test digit splitting with contiguous digits
 TEST_F(DigitsPreTokenizerTest, ContiguousDigits) {
   DigitsPreTokenizer ptok(false);
-  assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
-      {"The number ", "1", " then ", "234", " then ", "5", "."});
+  assert_split_match(ptok, "The number 1 then 234 then 5.",
+                     {"The number ", "1", " then ", "234", " then ", "5", "."});
 }
 
 // ByteLevelPreTokenizer ///////////////////////////////////////////////////////
@@ -75,8 +69,7 @@ TEST_F(ByteLevelPreTokenizerTest, PreTokenizeDefault) {
   ByteLevelPreTokenizer ptok;
   assert_split_match(ptok, "Hello World", {"ĠHello", "ĠWorld"});
   assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
+      ptok, "The number 1 then 234 then 5.",
       {"ĠThe", "Ġnumber", "Ġ1", "Ġthen", "Ġ234", "Ġthen", "Ġ5", "."});
 }
 
@@ -97,22 +90,9 @@ TEST_F(SequencePreTokenizerTest, PreTokenizeDigitAndByteLevel) {
   PreTokenizer::Ptr dptok(new DigitsPreTokenizer(true));
   PreTokenizer::Ptr bptok(new ByteLevelPreTokenizer(false));
   SequencePreTokenizer ptok({dptok, bptok});
-  assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
-      {"The",
-       "Ġnumber",
-       "Ġ",
-       "1",
-       "Ġthen",
-       "Ġ",
-       "2",
-       "3",
-       "4",
-       "Ġthen",
-       "Ġ",
-       "5",
-       "."});
+  assert_split_match(ptok, "The number 1 then 234 then 5.",
+                     {"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4",
+                      "Ġthen", "Ġ", "5", "."});
 }
 
 // PreTokenizerConfig //////////////////////////////////////////////////////////
@@ -152,14 +132,12 @@ TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) {
 
   // Sequence
   EXPECT_THROW(PreTokenizerConfig("Sequence").create(), std::runtime_error);
-  EXPECT_THROW(
-      PreTokenizerConfig("Sequence").set_pretokenizers({}).create(),
-      std::runtime_error);
-  EXPECT_THROW(
-      PreTokenizerConfig("Sequence")
-          .set_pretokenizers({PreTokenizerConfig("Split")})
-          .create(),
-      std::runtime_error);
+  EXPECT_THROW(PreTokenizerConfig("Sequence").set_pretokenizers({}).create(),
+               std::runtime_error);
+  EXPECT_THROW(PreTokenizerConfig("Sequence")
+                   .set_pretokenizers({PreTokenizerConfig("Split")})
+                   .create(),
+               std::runtime_error);
 
   // Unsupported
   EXPECT_THROW(PreTokenizerConfig("Unsupported").create(), std::runtime_error);
@@ -183,22 +161,9 @@ TEST_F(PreTokenizerConfigTest, ParseJson) {
                  }},
              })
          .create();
-  assert_split_match(
-      *ptok,
-      "The number 1 then 234 then 5.",
-      {"The",
-       "Ġnumber",
-       "Ġ",
-       "1",
-       "Ġthen",
-       "Ġ",
-       "2",
-       "3",
-       "4",
-       "Ġthen",
-       "Ġ",
-       "5",
-       "."});
+  assert_split_match(*ptok, "The number 1 then 234 then 5.",
+                     {"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4",
+                      "Ġthen", "Ġ", "5", "."});
 }
 
 TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
@@ -208,10 +173,8 @@ TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
                  {"type", "Digits"},
              })
          .create();
-  assert_split_match(
-      *ptok,
-      "The number 1 then 234 then 5.",
-      {"The number ", "1", " then ", "234", " then ", "5", "."});
+  assert_split_match(*ptok, "The number 1 then 234 then 5.",
+                     {"The number ", "1", " then ", "234", " then ", "5", "."});
 }
 
 TEST_F(PreTokenizerConfigTest, Split) {

test/test_sentencepiece.cpp

Lines changed: 6 additions & 18 deletions
@@ -7,26 +7,11 @@
  */
 // @lint-ignore-every LICENSELINT
 
-#ifdef TOKENIZERS_FB_BUCK
-#include <TestResourceUtils/TestResourceUtils.h>
-#endif
 #include <gtest/gtest.h>
 #include <pytorch/tokenizers/sentencepiece.h>
 
 namespace tokenizers {
 
-namespace {
-static inline std::string _get_resource_path(const std::string& name) {
-#ifdef TOKENIZERS_FB_BUCK
-  return facebook::xplat::testing::getPathForTestResource(
-      "test/resources/" + name);
-#else
-  return std::getenv("RESOURCES_PATH") + std::string("/") + name;
-#endif
-}
-
-} // namespace
-
 TEST(SPTokenizerTest, TestEncodeWithoutLoad) {
   SPTokenizer tokenizer;
   std::string text = "Hello world!";
@@ -42,7 +27,8 @@ TEST(SPTokenizerTest, TestDecodeWithoutLoad) {
 
 TEST(SPTokenizerTest, TestLoad) {
   SPTokenizer tokenizer;
-  auto path = _get_resource_path("test_sentencepiece.model");
+  auto path =
+      std::getenv("RESOURCES_PATH") + std::string("/test_sentencepiece.model");
   auto error = tokenizer.load(path);
   EXPECT_EQ(error, Error::Ok);
 }
@@ -55,7 +41,8 @@ TEST(SPTokenizerTest, TestLoadInvalidPath) {
 
 TEST(SPTokenizerTest, TestEncode) {
   SPTokenizer tokenizer;
-  auto path = _get_resource_path("test_sentencepiece.model");
+  auto path =
+      std::getenv("RESOURCES_PATH") + std::string("/test_sentencepiece.model");
   auto error = tokenizer.load(path);
   EXPECT_EQ(error, Error::Ok);
   std::string text = "Hello world!";
@@ -70,7 +57,8 @@ TEST(SPTokenizerTest, TestEncode) {
 
 TEST(SPTokenizerTest, TestDecode) {
   SPTokenizer tokenizer;
-  auto path = _get_resource_path("test_sentencepiece.model");
+  auto path =
+      std::getenv("RESOURCES_PATH") + std::string("/test_sentencepiece.model");
   auto error = tokenizer.load(path);
   EXPECT_EQ(error, Error::Ok);
   std::vector<uint64_t> tokens = {1, 15043, 3186, 29991};
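
Note: the llama2c and SentencePiece tests above now build their model paths the same way, by prepending the RESOURCES_PATH environment variable to the resource file name, instead of going through the removed _get_resource_path helper. A minimal standalone sketch of that pattern is below; the helper name and the null check are illustrative additions for clarity, not part of this commit (the tests themselves assume RESOURCES_PATH is always set).

#include <cstdlib>
#include <stdexcept>
#include <string>

// Hypothetical helper mirroring the inline pattern used in the updated tests:
// resolve a test resource relative to the RESOURCES_PATH environment variable.
static std::string resource_path(const std::string& name) {
  const char* base = std::getenv("RESOURCES_PATH");
  if (base == nullptr) {
    throw std::runtime_error("RESOURCES_PATH environment variable is not set");
  }
  return std::string(base) + "/" + name;
}

// Usage, analogous to the updated tests:
//   auto path = resource_path("test_sentencepiece.model");
//   auto error = tokenizer.load(path);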
