/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
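
// HFTokenizer implementation: loads a HuggingFace tokenizers-style
// tokenizer.json (and optional tokenizer_config.json), building the
// encoder/decoder maps, the pre-tokenizer, and the BOS/EOS token ids.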

#include "hf_tokenizer.h"

// Standard
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Third Party
#include <nlohmann/json.hpp>

namespace fs = std::filesystem;
using json = nlohmann::json;

namespace tokenizers {

// -------------------------private method end-------------------------------
// -------------------------public method start------------------------------

Error HFTokenizer::load(const std::string& path) {
  // If this is a directory, look for tokenizer.json and tokenizer_config.json
  std::string model_json = path;
  std::string model_config_json = "";
  if (fs::is_directory(path)) {
    const fs::path root(path);
    model_json = root / "tokenizer.json";
    if (!fs::exists(model_json)) {
      fprintf(stderr, "no tokenizer.json found in %s\n", path.c_str());
      return Error::LoadFailure;
    }
    const auto model_config_json_path = root / "tokenizer_config.json";
    if (fs::exists(model_config_json_path)) {
      model_config_json = model_config_json_path;
    }
  }
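
  // Otherwise `path` is treated as the tokenizer.json file itself, and no
  // tokenizer_config.json is consulted.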

  // Load the tokenizer.json file
  std::ifstream file(model_json);
  if (!file) {
    fprintf(stderr, "failed to open encoder file: %s\n", path.c_str());
    return Error::LoadFailure;
  }
  std::string contents(
      (std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  json parsed_json;
  try {
    parsed_json = json::parse(contents);
  } catch (const json::exception& e) {
    std::cerr << "Error parsing json file: " << e.what() << std::endl;
    return Error::LoadFailure;
  }
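
  // Each "added_tokens" entry carries at least an integer "id" and a string
  // "content"; both directions are recorded so encode/decode can special-case
  // these tokens later.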
  // Parse the special tokens
  try {
    const auto& special_tokens = parsed_json.at("added_tokens");
    for (auto it = special_tokens.begin(); it != special_tokens.end(); ++it) {
      const std::string token = it->at("content");
      const uint64_t token_id = it->at("id");
      if (!special_token_encoder_.emplace(token, token_id).second) {
        fprintf(stderr, "duplicate special token: %s\n", token.c_str());
        return Error::LoadFailure;
      }
      if (!special_token_decoder_.emplace(token_id, token).second) {
        fprintf(
            stderr,
            "duplicate special token id: %llu\n",
            (unsigned long long)token_id);
        return Error::LoadFailure;
      }
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
    return Error::LoadFailure;
  }
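
  // "/model/vocab"_json_pointer is an nlohmann::json JSON Pointer (RFC 6901)
  // literal that reaches the vocab map nested under the "model" object.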
  // Parse the standard tokens
  try {
    const auto& vocab = parsed_json.at("/model/vocab"_json_pointer);
    for (const auto& entry : vocab.items()) {
      const std::string token = entry.key();
      const uint64_t token_id = entry.value();
      // Skip adding special tokens to the standard encoder/decoder
      if (special_token_decoder_.find(token_id) ==
          special_token_decoder_.end()) {
        if (!encoder_.emplace(token, token_id).second) {
          fprintf(stderr, "duplicate token: %s\n", token.c_str());
          return Error::LoadFailure;
        }
        if (!decoder_.emplace(token_id, token).second) {
          fprintf(
              stderr,
              "duplicate token id: %llu\n",
              (unsigned long long)token_id);
          return Error::LoadFailure;
        }
      }
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse tokens: %s\n", e.what());
    return Error::LoadFailure;
  }
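
  // encoder_ and special_token_encoder_ are disjoint by construction (special
  // ids were skipped above), so their sizes can simply be summed.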
  // Set the vocab size to include special tokens
  vocab_size_ = encoder_.size() + special_token_encoder_.size();

  // Set up the pre-tokenizer
  try {
    _pretokenizer =
        PreTokenizerConfig().parse_json(parsed_json.at("pre_tokenizer")).create();
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
    return Error::LoadFailure;
  }
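
  // PreTokenizerConfig / TokenDecoderConfig act as small config-driven
  // factories: parse_json() consumes the corresponding tokenizer.json node and
  // create() instantiates the matching implementation.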
  // Set up the decoder (optional)
  try {
    _decoder =
        TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create();
  } catch (const json::out_of_range&) {
    // No decoder specified
  }

  // TODO: Do we need to parse the merges?

  // If a tokenizer config file is found, parse it to look up the eos/bos tokens
  if (!model_config_json.empty()) {
    // Load it and parse it as json
    std::ifstream file(model_config_json);
    if (!file) {
      fprintf(
          stderr,
          "failed to open tokenizer config file: %s\n",
          model_config_json.c_str());
      return Error::LoadFailure;
    }
    std::string contents(
        (std::istreambuf_iterator<char>(file)),
        std::istreambuf_iterator<char>());
    json parsed_json;
    try {
      parsed_json = json::parse(contents);
    } catch (const json::exception& e) {
      std::cerr << "Error parsing model config json file: " << e.what()
                << std::endl;
      return Error::LoadFailure;
    }

    // Pull out the token strings
    // NOTE: This assumes "bos_token"/"eos_token" are plain strings. Some
    // configs store them as objects with a "content" field; those would need
    // extra handling here.
    try {
      const std::string bos_token = parsed_json.at("bos_token");
      const std::string eos_token = parsed_json.at("eos_token");
      const auto bos_it = special_token_encoder_.find(bos_token);
      const auto eos_it = special_token_encoder_.find(eos_token);
      if (bos_it == special_token_encoder_.end()) {
        fprintf(
            stderr, "BOS token %s not in special tokens\n", bos_token.c_str());
        return Error::LoadFailure;
      }
      if (eos_it == special_token_encoder_.end()) {
        fprintf(
            stderr, "EOS token %s not in special tokens\n", eos_token.c_str());
        return Error::LoadFailure;
      }
      bos_tok_ = bos_it->second;
      eos_tok_ = eos_it->second;
    } catch (const json::out_of_range& e) {
      fprintf(
          stderr,
          "Could not parse eos/bos from tokenizer config: %s\n",
          e.what());
      return Error::LoadFailure;
    }
  }

  // Otherwise, make an educated guess with the following logic:
  // 1. Look for special tokens with "bos"/"begin" or "eos"/"end" in them
  // 2. Sub-qualify with the word "text" if needed
  // 3. If EOS found, but BOS is not (or vice versa), assume they are the same
  else {
    std::vector<std::string> bos_candidates;
    std::vector<std::string> eos_candidates;
    for (const auto& token : special_token_encoder_) {
      if (token.first.find("bos") != std::string::npos ||
          token.first.find("begin") != std::string::npos) {
        bos_candidates.push_back(token.first);
      }
      if (token.first.find("eos") != std::string::npos ||
          token.first.find("end") != std::string::npos) {
        eos_candidates.push_back(token.first);
      }
    }
    if (bos_candidates.size() > 1) {
      const auto orig_candidates = bos_candidates;
      bos_candidates.clear();
      for (const auto& cand : orig_candidates) {
        if (cand.find("text") != std::string::npos) {
          bos_candidates.push_back(cand);
        }
      }
    }
    if (eos_candidates.size() > 1) {
      const auto orig_candidates = eos_candidates;
      eos_candidates.clear();
      for (const auto& cand : orig_candidates) {
        if (cand.find("text") != std::string::npos) {
          eos_candidates.push_back(cand);
        }
      }
    }

    // Use if a single candidate
    bool bos_found = false;
    bool eos_found = false;
    if (bos_candidates.size() == 1) {
      bos_found = true;
      bos_tok_ = special_token_encoder_[bos_candidates[0]];
    }
    if (eos_candidates.size() == 1) {
      eos_found = true;
      eos_tok_ = special_token_encoder_[eos_candidates[0]];
    }

    // Make them the same if only one found
    if (bos_found && !eos_found) {
      eos_tok_ = bos_tok_;
    } else if (!bos_found && eos_found) {
      bos_tok_ = eos_tok_;
    }
  }

  // Mark initialized once everything is done
  initialized_ = true;

  return Error::Ok;
}
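
// Usage sketch (hypothetical caller; `HFTokenizer` and `Error` are as declared
// in hf_tokenizer.h):
//
//   tokenizers::HFTokenizer tok;
//   if (tok.load("/path/to/model_dir") != tokenizers::Error::Ok) {
//     // handle load failure
//   }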

// -------------------------public method end--------------------------------
// -------------------------private method start-----------------------------

Error HFTokenizer::_encode(
    re2::StringPiece& input,
    std::vector<uint64_t>& ret,
    uint64_t& last_piece_token_len) const {
  for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
    // If the whole piece is in the vocab, emit its id directly
    auto iter = encoder_.find(piece);
    if (iter != encoder_.end()) {
      last_piece_token_len = 1;
      ret.push_back(iter->second);
      continue;
    }
    // Otherwise fall back to byte-pair encoding the piece
    auto tokens = TK_UNWRAP(byte_pair_encode_(piece, encoder_));
    last_piece_token_len = tokens.size();
    ret.insert(ret.end(), tokens.begin(), tokens.end());
  }
  return Error::Ok;
}
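
// Append the decoded form of `input` to `ret`. If tokenizer.json configured a
// token decoder (e.g. byte-level), it is applied; otherwise the text passes
// through unchanged.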
void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const {
  if (_decoder) {
    ret += _decoder->decode(input);
  } else {
    ret += input;
  }
}

} // namespace tokenizers