Commit acad6fe

feat: Add HFTokenizer implementation
This still doesn't implement the post_processor portion of the HF tokenizers library.

Branch: HFTokenizers

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent bcbfcc2 commit acad6fe

2 files changed, 324 insertions(+), 0 deletions(-)


include/hf_tokenizer.h

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Used by many Huggingface models. Adapted from a combination of the original
// rust implementation (https://github.com/huggingface/tokenizers/tree/main)
// and the corresponding support in llama.cpp
// (https://github.com/ggerganov/llama.cpp)
#pragma once

// Standard
#include <cstdint>
#include <string>
#include <vector>

// Third Party
#include <re2/re2.h>

// Local
#include "detail/bpe_tokenizer_base.h"
#include "error.h"
#include "pre_tokenizer.h"
#include "result.h"
#include "token_decoder.h"

namespace tokenizers {

class HFTokenizer : public detail::BPETokenizerBase {
 public:
  /*-- Public Interface --*/

  /**
   * Default initialize with no loaded data
   */
  explicit HFTokenizer() {}
  ~HFTokenizer() {}

  /**
   * Load the model data into the tokenizer from a tokenizer.json file or a
   * directory containing one
   */
  Error load(const std::string& tokenizer_path) override;

 private:
  Error _encode(
      re2::StringPiece& input,
      std::vector<uint64_t>& ret,
      uint64_t& last_piece_token_len) const override;

  void _decode(re2::StringPiece input, std::string& ret) const override;

  PreTokenizer::Ptr _pretokenizer;
  TokenDecoder::Ptr _decoder;
};

} // namespace tokenizers
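
For orientation, here is a minimal usage sketch of the new class (not part of this commit). The public encode entry point is inherited from detail::BPETokenizerBase, which is not shown in this diff, so the encode signature and Result accessors below are assumptions modeled on comparable tokenizer interfaces:

// Hypothetical usage sketch; not part of this commit.
#include <cstdint>
#include <cstdio>

#include "hf_tokenizer.h"

int main() {
  tokenizers::HFTokenizer tok;

  // load() accepts either a path to tokenizer.json or a directory that
  // contains tokenizer.json (and, optionally, tokenizer_config.json).
  if (tok.load("path/to/tokenizer.json") != tokenizers::Error::Ok) {
    fprintf(stderr, "failed to load tokenizer\n");
    return 1;
  }

  // Assumed base-class signature: encode(text, n_bos, n_eos) returning a
  // Result wrapping the token ids; adjust to the actual interface.
  auto encoded = tok.encode("Hello world", /*bos=*/1, /*eos=*/0);
  if (encoded.ok()) {
    for (uint64_t id : encoded.get()) {
      printf("%llu ", (unsigned long long)id);
    }
    printf("\n");
  }
  return 0;
}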

src/hf_tokenizer.cpp

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "hf_tokenizer.h"

// Standard
#include <cinttypes>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>

// Third Party
#include <nlohmann/json.hpp>

namespace fs = std::filesystem;
using json = nlohmann::json;

namespace tokenizers {

// -------------------------public method start-------------------------------

Error HFTokenizer::load(const std::string& path) {
  // If this is a directory, look for tokenizer.json and tokenizer_config.json
  std::string model_json = path;
  std::string model_config_json = "";
  if (fs::is_directory(path)) {
    const fs::path root(path);
    model_json = (root / "tokenizer.json").string();
    if (!fs::exists(model_json)) {
      fprintf(stderr, "no tokenizer.json found in %s\n", path.c_str());
      return Error::LoadFailure;
    }
    const auto model_config_json_path = root / "tokenizer_config.json";
    if (fs::exists(model_config_json_path)) {
      model_config_json = model_config_json_path.string();
    }
  }

  // Load the tokenizer.json file
  std::ifstream file(model_json);
  if (!file) {
    fprintf(stderr, "failed to open tokenizer file: %s\n", model_json.c_str());
    return Error::LoadFailure;
  }
  std::string contents(
      (std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  json parsed_json;
  try {
    parsed_json = json::parse(contents);
  } catch (const json::exception& e) {
    std::cerr << "Error parsing json file: " << e.what() << std::endl;
    return Error::LoadFailure;
  }

  // Parse the special tokens
  try {
    const auto& special_tokens = parsed_json.at("added_tokens");
    for (auto it = special_tokens.begin(); it != special_tokens.end(); ++it) {
      const std::string token = it->at("content");
      const uint64_t token_id = it->at("id");
      if (!special_token_encoder_.emplace(token, token_id).second) {
        fprintf(stderr, "duplicate special token: %s\n", token.c_str());
        return Error::LoadFailure;
      }
      if (!special_token_decoder_.emplace(token_id, token).second) {
        fprintf(stderr, "duplicate special token id: %" PRIu64 "\n", token_id);
        return Error::LoadFailure;
      }
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
    return Error::LoadFailure;
  }

  // Parse the standard tokens
  try {
    const auto& vocab = parsed_json.at("/model/vocab"_json_pointer);
    for (const auto& entry : vocab.items()) {
      const std::string token = entry.key();
      const uint64_t token_id = entry.value();
      // Skip adding special tokens to the standard encoder/decoder
      if (special_token_decoder_.find(token_id) ==
          special_token_decoder_.end()) {
        if (!encoder_.emplace(token, token_id).second) {
          fprintf(stderr, "duplicate token: %s\n", token.c_str());
          return Error::LoadFailure;
        }
        if (!decoder_.emplace(token_id, token).second) {
          fprintf(stderr, "duplicate token id: %" PRIu64 "\n", token_id);
          return Error::LoadFailure;
        }
      }
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse tokens: %s\n", e.what());
    return Error::LoadFailure;
  }

  // Set the vocab size to include special tokens
  vocab_size_ = encoder_.size() + special_token_encoder_.size();

  // Set up the pre-tokenizer
  try {
    _pretokenizer = PreTokenizerConfig()
                        .parse_json(parsed_json.at("pre_tokenizer"))
                        .create();
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
    return Error::LoadFailure;
  }

  // Set up the decoder (optional)
  try {
    _decoder =
        TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create();
  } catch (const json::out_of_range& e) {
    // No decoder specified
  }

  // TODO: Do we need to parse the merges?

  // If a tokenizer config file is found, parse it to look up the eos/bos
  // tokens
  if (!model_config_json.empty()) {
    // Load it and parse it as json
    std::ifstream config_file(model_config_json);
    if (!config_file) {
      fprintf(
          stderr,
          "failed to open tokenizer config file: %s\n",
          model_config_json.c_str());
      return Error::LoadFailure;
    }
    std::string config_contents(
        (std::istreambuf_iterator<char>(config_file)),
        std::istreambuf_iterator<char>());
    json parsed_config_json;
    try {
      parsed_config_json = json::parse(config_contents);
    } catch (const json::exception& e) {
      std::cerr << "Error parsing model config json file: " << e.what()
                << std::endl;
      return Error::LoadFailure;
    }

    // Pull out the token strings
    try {
      const std::string bos_token = parsed_config_json.at("bos_token");
      const std::string eos_token = parsed_config_json.at("eos_token");
      const auto bos_it = special_token_encoder_.find(bos_token);
      const auto eos_it = special_token_encoder_.find(eos_token);
      if (bos_it == special_token_encoder_.end()) {
        fprintf(
            stderr, "BOS token %s not in special tokens\n", bos_token.c_str());
        return Error::LoadFailure;
      }
      if (eos_it == special_token_encoder_.end()) {
        fprintf(
            stderr, "EOS token %s not in special tokens\n", eos_token.c_str());
        return Error::LoadFailure;
      }
      bos_tok_ = bos_it->second;
      eos_tok_ = eos_it->second;
    } catch (const json::out_of_range& e) {
      fprintf(
          stderr,
          "Could not parse eos/bos from tokenizer config: %s\n",
          e.what());
      return Error::LoadFailure;
    }
  }

  // Otherwise, make an educated guess with the following logic:
  // 1. Look for special tokens with "bos"/"begin" or "eos"/"end" in them
  // 2. Sub-qualify with the word "text" if needed
  // 3. If EOS found, but BOS is not (or vice versa), assume they are the same
  else {
    std::vector<std::string> bos_candidates;
    std::vector<std::string> eos_candidates;
    for (const auto& token : special_token_encoder_) {
      if (token.first.find("bos") != std::string::npos ||
          token.first.find("begin") != std::string::npos) {
        bos_candidates.push_back(token.first);
      }
      if (token.first.find("eos") != std::string::npos ||
          token.first.find("end") != std::string::npos) {
        eos_candidates.push_back(token.first);
      }
    }
    if (bos_candidates.size() > 1) {
      const auto orig_candidates = bos_candidates;
      bos_candidates.clear();
      for (const auto& cand : orig_candidates) {
        if (cand.find("text") != std::string::npos) {
          bos_candidates.push_back(cand);
        }
      }
    }
    if (eos_candidates.size() > 1) {
      const auto orig_candidates = eos_candidates;
      eos_candidates.clear();
      for (const auto& cand : orig_candidates) {
        if (cand.find("text") != std::string::npos) {
          eos_candidates.push_back(cand);
        }
      }
    }

    // Use if a single candidate
    bool bos_found = false;
    bool eos_found = false;
    if (bos_candidates.size() == 1) {
      bos_found = true;
      bos_tok_ = special_token_encoder_[bos_candidates[0]];
    }
    if (eos_candidates.size() == 1) {
      eos_found = true;
      eos_tok_ = special_token_encoder_[eos_candidates[0]];
    }

    // Make them the same if only one found
    if (bos_found && !eos_found) {
      eos_tok_ = bos_tok_;
    } else if (!bos_found && eos_found) {
      bos_tok_ = eos_tok_;
    }
  }

  // Mark initialized once everything is done
  initialized_ = true;

  return Error::Ok;
}
// -------------------------public method end---------------------------------
// -------------------------private method start------------------------------

Error HFTokenizer::_encode(
    re2::StringPiece& input,
    std::vector<uint64_t>& ret,
    uint64_t& last_piece_token_len) const {
  for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
    auto iter = encoder_.find(piece);
    if (iter != encoder_.end()) {
      last_piece_token_len = 1;
      ret.push_back(iter->second);
      continue;
    }
    auto tokens = TK_UNWRAP(byte_pair_encode_(piece, encoder_));
    last_piece_token_len = tokens.size();
    ret.insert(ret.end(), tokens.begin(), tokens.end());
  }
  return Error::Ok;
}

void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const {
  if (_decoder) {
    ret += _decoder->decode(input);
  } else {
    ret += input;
  }
}

} // namespace tokenizers
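
As a format note, load() reads four regions of tokenizer.json: added_tokens, /model/vocab, pre_tokenizer, and decoder. The toy document below is illustrative only (it is not a real model file) and exercises the same json-pointer lookup the loader uses for the vocab:

// Illustrative only: a toy tokenizer.json covering the fields read by
// HFTokenizer::load(). Uses nlohmann::json directly, no repo classes.
#include <cinttypes>
#include <cstdint>
#include <cstdio>

#include <nlohmann/json.hpp>

int main() {
  const char* doc = R"({
    "added_tokens": [{"id": 0, "content": "<|begin_of_text|>"}],
    "model": {"vocab": {"hello": 1, "world": 2}},
    "pre_tokenizer": {"type": "ByteLevel"},
    "decoder": {"type": "ByteLevel"}
  })";
  const auto parsed = nlohmann::json::parse(doc);

  // Same json-pointer lookup used for the standard vocab in load()
  const auto& vocab = parsed.at(nlohmann::json::json_pointer("/model/vocab"));
  for (const auto& entry : vocab.items()) {
    printf(
        "%s -> %" PRIu64 "\n",
        entry.key().c_str(),
        entry.value().get<uint64_t>());
  }
  return 0;
}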
