@@ -24,9 +24,17 @@ using Encoder = std::unordered_map<std::string, uint64_t>;
 using Decoder = std::unordered_map<uint64_t, std::string>;
 using Re2UPtr = std::unique_ptr<re2::RE2>;
 
+constexpr int32_t kSpecialTokensSize = 256;
+
+enum Version {
+  DEFAULT,
+  MULTIMODAL,
+};
+
 class Tiktoken : public Tokenizer {
  public:
-  explicit Tiktoken() : Tokenizer() {}
+  explicit Tiktoken(const Version& version = DEFAULT)
+      : Tokenizer(), _version(version) {}
   ~Tiktoken(){};
 
   Error load(const std::string& tokenizer_path) override;
@@ -38,26 +46,103 @@ class Tiktoken : public Tokenizer {
       const override;
 
  private:
-  static inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
+  static inline const Encoder _get_default_special_tokens(
+      ssize_t num_base_tokens) {
     Encoder special_tokens;
-    special_tokens.emplace("<|begin_of_text|>", num_base_tokens++);
-    special_tokens.emplace("<|end_of_text|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_0|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_1|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_2|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_3|>", num_base_tokens++);
-    special_tokens.emplace("<|start_header_id|>", num_base_tokens++);
-    special_tokens.emplace("<|end_header_id|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_4|>", num_base_tokens++);
-    special_tokens.emplace("<|eot_id|>", num_base_tokens++);
-    for (auto i = 5; i < 251; ++i) {
+    ssize_t special_token_count = 0;
+    special_tokens.emplace(
+        "<|begin_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_0|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_1|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_2|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_3|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|start_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_4|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|eot_id|>", num_base_tokens + special_token_count++);
+
+    // pad the rest of the special tokens with reserved tokens
+    ssize_t reserved_special_token_num = 5;
+    while (special_token_count < kSpecialTokensSize) {
       special_tokens.emplace(
-          "<|reserved_special_token_" + std::to_string(i) + "|>",
-          num_base_tokens++);
+          "<|reserved_special_token_" +
+              std::to_string(reserved_special_token_num++) + "|>",
+          num_base_tokens + special_token_count++);
     }
     return special_tokens;
   }
 
+  static inline const Encoder _get_multimodal_special_tokens(
+      ssize_t num_base_tokens) {
+    ssize_t special_token_count = 0;
+    Encoder special_tokens;
+    special_tokens.emplace(
+        "<|begin_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_of_text|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_0|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_1|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_2|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|reserved_special_token_3|>",
+        num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|start_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|end_header_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|eom_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|eot_id|>", num_base_tokens + special_token_count++);
+    special_tokens.emplace(
+        "<|image|>", num_base_tokens + special_token_count++);
+
+    // pad the rest of the special tokens with reserved tokens except the last
+    // one
+    ssize_t reserved_special_token_num = 4;
+    while (special_token_count < kSpecialTokensSize - 1) {
+      special_tokens.emplace(
+          "<|reserved_special_token_" +
+              std::to_string(reserved_special_token_num++) + "|>",
+          num_base_tokens + special_token_count++);
+    }
+
+    special_tokens.emplace(
+        "<|python_tag|>", num_base_tokens + special_token_count++);
+
+    return special_tokens;
+  }
+
+  inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
+    switch (_version) {
+      case MULTIMODAL:
+        return _get_multimodal_special_tokens(num_base_tokens);
+      default:
+        return _get_default_special_tokens(num_base_tokens);
+    }
+  }
+
   template <typename T>
   std::pair<std::optional<std::string>, re2::StringPiece>
   _split_with_allowed_special_token(
@@ -74,6 +159,8 @@ class Tiktoken : public Tokenizer {
       const std::string& text,
       const T& allowed_special) const;
 
+  const Version _version;
+
   // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
   const std::string _pattern =
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
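A minimal usage sketch of the new constructor follows. It is not part of the diff: the helper name, the include path, and the error handling are placeholder assumptions; only the Version argument, load(), and the special-token counts come from the header changes above.

// Hypothetical usage sketch, not part of this PR.
// Assumptions: tiktoken.h include path, helper name make_tiktoken, and
// how the returned Error is checked; namespace qualifiers, if any, omitted.
#include <memory>
#include <string>

#include "tiktoken.h"  // header changed in this diff; path assumed

std::unique_ptr<Tiktoken> make_tiktoken(
    const std::string& model_path,
    bool multimodal) {
  // The new Version argument picks which special-token table gets appended
  // after the base BPE ranks: DEFAULT or MULTIMODAL.
  auto tok = std::make_unique<Tiktoken>(multimodal ? MULTIMODAL : DEFAULT);

  // load() reads the tokenizer model and registers the special tokens on
  // top of the base vocabulary.
  Error err = tok->load(model_path);
  (void)err;  // check err with whatever the project's Error API provides

  return tok;
}

Note that both table builders pad with reserved tokens up to kSpecialTokensSize (256) entries, the multimodal one reserving the last slot for <|python_tag|>, so the total vocabulary size stays the same regardless of the selected Version.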