Commit c2147cb

Add Tiktoken v5 vision tokenizer (#4086)

David Lin authored and facebook-github-bot committed

Summary:
Pull Request resolved: #4086

Added a new param so that users can decide which version of Tiktoken will be used. Refactored the special tokens to be more streamlined and to reserve 256 special tokens independent of the Tiktoken version. Copied test cases from https://www.internalfb.com/code/fbsource/fbcode/gen_ai/llm_inference/llm_common/tokenizers/protocols/tiktokenv5_vision_test_cases.json

Reviewed By: larryliu0820

Differential Revision: D59074258

fbshipit-source-id: 16f3b1fce68939d464b56f4edb869fe293c6b0b8

1 parent 28a45cd commit c2147cb
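
As context for the diff below: the new constructor parameter is the only API change. Here is a minimal caller-side sketch using the `Version` enum and constructor added in `tiktoken.h`; the include path, elided using-declarations, and model path are illustrative assumptions, not part of this commit.

```cpp
// Minimal sketch, not part of the commit. Tiktoken, Version, and Error come
// from the tiktoken.h diff below; using-declarations are elided as in the
// tests, and the include/model paths are placeholders.
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>

void build_tokenizers() {
  // Old behavior, unchanged: Version defaults to DEFAULT (text-only tokens).
  Tiktoken text_tokenizer;

  // New in this commit: multimodal special tokens (<|image|>, <|eom_id|>,
  // <|python_tag|>, ...) selected at construction time.
  Tiktoken vision_tokenizer(MULTIMODAL);

  // Loading and encoding follow the existing Tokenizer interface.
  Error res = vision_tokenizer.load("/path/to/tokenizer.model");
  (void)res;
}
```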

File tree

2 files changed: +172 −15 lines changed


examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 70 additions & 0 deletions

@@ -30,6 +30,19 @@ class TiktokenExtensionTest : public Test {
   std::string modelPath_;
 };
 
+class MultimodalTiktokenV5ExtensionTest : public Test {
+ public:
+  void SetUp() override {
+    torch::executor::runtime_init();
+    tokenizer_ = std::make_unique<Tiktoken>(MULTIMODAL);
+    modelPath_ =
+        std::getenv("RESOURCES_PATH") + std::string("/tokenizer.model");
+  }
+
+  std::unique_ptr<Tokenizer> tokenizer_;
+  std::string modelPath_;
+};
+
 TEST_F(TiktokenExtensionTest, EncodeWithoutLoadFails) {
   Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
   EXPECT_EQ(res.error(), Error::NotSupported);
@@ -50,6 +63,16 @@ TEST_F(TiktokenExtensionTest, TokenizerVocabSizeIsExpected) {
   EXPECT_EQ(tokenizer_->eos_tok(), 128001);
 }
 
+TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerVocabSizeIsExpected) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  // test.bin has vocab size 0 but the tokenizer respects the vocab size being
+  // passed in and adds placeholder tokens.
+  EXPECT_EQ(tokenizer_->vocab_size(), 128256);
+  EXPECT_EQ(tokenizer_->bos_tok(), 128000);
+  EXPECT_EQ(tokenizer_->eos_tok(), 128001);
+}
+
 TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
@@ -63,6 +86,29 @@ TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
   EXPECT_EQ(out.get()[2], 1917);
 }
 
+TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerEncodeCorrectly) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  // test.bin has vocab size 0 but the tokenizer respects the vocab size being
+  // passed in and adds placeholder tokens.
+  Result<std::vector<uint64_t>> out = tokenizer_->encode(
+      "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|>What do you think is going on in this snapshot?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAmidst a scenic garden backdrop, a man dressed in a suit with a distinct button on its lower portion stands prominently.<|eom_id|>",
+      0,
+      0);
+  EXPECT_EQ(out.error(), Error::Ok);
+  EXPECT_EQ(out.get().size(), 48);
+  std::vector<uint64_t> expected_out = {
+      128000, 128006, 882, 128007, 271, 128010, 3923, 656,
+      499, 1781, 374, 2133, 389, 304, 420, 16694,
+      30, 128009, 128006, 78191, 128007, 271, 6219, 307,
+      267, 264, 62081, 13863, 39577, 11, 264, 893,
+      26435, 304, 264, 7937, 449, 264, 12742, 3215,
+      389, 1202, 4827, 13651, 13656, 74088, 13, 128008};
+  for (size_t i = 0; i < expected_out.size(); ++i) {
+    EXPECT_EQ(expected_out[i], out.get()[i]);
+  }
+}
+
 TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
@@ -77,6 +123,30 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
   }
 }
 
+TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  // test.bin has vocab size 0 but the tokenizer respects the vocab size being
+  // passed in and adds placeholder tokens.
+  std::vector<std::string> expected = {
+      "<|begin_of_text|>",
+      "<|start_header_id|>",
+      "user",
+      "<|end_header_id|>",
+      "<|image|>",
+      "<|image|>",
+      "hello",
+      "<|image|>",
+      "<|eom_id|>"};
+  std::vector<uint64_t> tokens = {
+      128000, 128006, 882, 128007, 128010, 128010, 15339, 128010, 128008};
+  for (size_t i = 0; i < tokens.size(); i++) {
+    Result<std::string> out = tokenizer_->decode(0, tokens[i]);
+    EXPECT_EQ(out.error(), Error::Ok);
+    EXPECT_EQ(out.get(), expected[i]);
+  }
+}
+
 TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
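
The magic numbers in the multimodal tests above fall out of the reservation scheme in `tiktoken.h` below: `kSpecialTokensSize = 256` special tokens are appended after the base vocabulary, with `<|begin_of_text|>` and `<|end_of_text|>` in the first two slots. A standalone sanity check of that arithmetic follows; the 128000 base-token count is an assumption implied by `bos_tok() == 128000`, not stated in the diff.

```cpp
// Standalone arithmetic check of the ids the tests assert; not part of the
// commit. Assumption: 128000 base BPE tokens, implied by bos_tok() == 128000.
#include <cassert>
#include <cstdint>

int main() {
  constexpr int32_t kNumBaseTokens = 128000;   // assumed, not in the diff
  constexpr int32_t kSpecialTokensSize = 256;  // constant from tiktoken.h

  // vocab_size() = base tokens + reserved special tokens.
  assert(kNumBaseTokens + kSpecialTokensSize == 128256);
  // <|begin_of_text|> takes slot 0 and <|end_of_text|> slot 1 in both
  // versions, matching bos_tok() and eos_tok().
  assert(kNumBaseTokens + 0 == 128000);
  assert(kNumBaseTokens + 1 == 128001);
  return 0;
}
```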

examples/models/llama2/tokenizer/tiktoken.h

Lines changed: 102 additions & 15 deletions

@@ -24,9 +24,17 @@ using Encoder = std::unordered_map<std::string, uint64_t>;
 using Decoder = std::unordered_map<uint64_t, std::string>;
 using Re2UPtr = std::unique_ptr<re2::RE2>;
 
+constexpr int32_t kSpecialTokensSize = 256;
+
+enum Version {
+  DEFAULT,
+  MULTIMODAL,
+};
+
 class Tiktoken : public Tokenizer {
  public:
-  explicit Tiktoken() : Tokenizer() {}
+  explicit Tiktoken(const Version& version = DEFAULT)
+      : Tokenizer(), _version(version) {}
   ~Tiktoken(){};
 
   Error load(const std::string& tokenizer_path) override;
@@ -38,26 +46,103 @@ class Tiktoken : public Tokenizer {
       const override;
 
  private:
-  static inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
+  static inline const Encoder _get_default_special_tokens(
+      ssize_t num_base_tokens) {
     Encoder special_tokens;
-    special_tokens.emplace("<|begin_of_text|>", num_base_tokens++);
-    special_tokens.emplace("<|end_of_text|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_0|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_1|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_2|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_3|>", num_base_tokens++);
-    special_tokens.emplace("<|start_header_id|>", num_base_tokens++);
-    special_tokens.emplace("<|end_header_id|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_4|>", num_base_tokens++);
-    special_tokens.emplace("<|eot_id|>", num_base_tokens++);
-    for (auto i = 5; i < 251; ++i) {
+    ssize_t special_token_count = 0;
+    special_tokens.emplace(
+        "<|begin_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_0|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_1|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_2|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_3|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|start_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_4|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|eot_id|>", num_base_tokens + special_token_count++);
+
+    // pad the rest of the special tokens with reserved tokens
+    ssize_t reserved_special_token_num = 5;
+    while (special_token_count < kSpecialTokensSize) {
       special_tokens.emplace(
-          "<|reserved_special_token_" + std::to_string(i) + "|>",
-          num_base_tokens++);
+          "<|reserved_special_token_" +
+              std::to_string(reserved_special_token_num++) + "|>",
+          num_base_tokens + special_token_count++);
     }
     return special_tokens;
   }
 
+  static inline const Encoder _get_multimodal_special_tokens(
+      ssize_t num_base_tokens) {
+    ssize_t special_token_count = 0;
+    Encoder special_tokens;
+    special_tokens.emplace(
+        "<|begin_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_0|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_1|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_2|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_3|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|start_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|eom_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|eot_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|image|>", num_base_tokens + special_token_count++);
+
+    // pad the rest of the special tokens with reserved tokens except the last
+    // one
+    ssize_t reserved_special_token_num = 4;
+    while (special_token_count < kSpecialTokensSize - 1) {
+      special_tokens.emplace(
+          "<|reserved_special_token_" +
+              std::to_string(reserved_special_token_num++) + "|>",
+          num_base_tokens + special_token_count++);
+    }
+
+    special_tokens.emplace(
+        "<|python_tag|>", num_base_tokens + special_token_count++);
+
+    return special_tokens;
+  }
+
+  inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
+    switch (_version) {
+      case MULTIMODAL:
+        return _get_multimodal_special_tokens(num_base_tokens);
+      default:
+        return _get_default_special_tokens(num_base_tokens);
+    }
+  }
+
   template <typename T>
   std::pair<std::optional<std::string>, re2::StringPiece>
   _split_with_allowed_special_token(
@@ -74,6 +159,8 @@ class Tiktoken : public Tokenizer {
       const std::string& text,
       const T& allowed_special) const;
 
+  const Version _version;
+
   // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
   const std::string _pattern =
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
