Skip to content

Commit 0e5dbed

Browse files
committed
Update base for Update on "Remove llama related stuff out of bpe_tokenizer"
We don't need to initialize `vocab_`, `vocab_scores_`, etc.; they will be initialized anyway while loading the tokenizer binary. A benefit of removing them is that we can drop these llama-related default values and make `bpe_tokenizer` model-agnostic. Differential Revision: [D59664556](https://our.internmc.facebook.com/intern/diff/D59664556/) [ghstack-poisoned]
1 parent 1b5184d commit 0e5dbed

File tree

6 files changed

+122
-95
lines changed

6 files changed

+122
-95
lines changed

examples/models/llama2/runner/runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ Error Runner::load() {
8181

8282
// Load tokenizer
8383
#if ET_USE_TIKTOKEN
84-
tokenizer_ = std::make_unique<LlamaTiktoken>();
84+
tokenizer_ = get_tiktoken_for_llama();
8585
#else
8686
tokenizer_ = std::make_unique<BPETokenizer>();
8787
#endif

examples/models/llama2/tokenizer/llama_tiktoken.cpp

Lines changed: 45 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -12,98 +12,77 @@ namespace torch {
1212
namespace executor {
1313
namespace {
1414
static constexpr int32_t kSpecialTokensSize = 256;
15+
static std::string kBOSToken = "<|begin_of_text|>";
16+
static constexpr size_t kBOSTokenIndex = 0;
17+
static std::string kEOSToken = "<|end_of_text|>";
18+
static constexpr size_t kEOSTokenIndex = 1;
1519

16-
static inline const Encoder _get_default_special_tokens(
17-
ssize_t num_base_tokens) {
18-
Encoder special_tokens;
19-
ssize_t special_token_count = 0;
20-
special_tokens.emplace(
21-
"<|begin_of_text|>", num_base_tokens + special_token_count++);
22-
special_tokens.emplace(
23-
"<|end_of_text|>", num_base_tokens + special_token_count++);
24-
special_tokens.emplace(
25-
"<|reserved_special_token_0|>", num_base_tokens + special_token_count++);
26-
special_tokens.emplace(
27-
"<|reserved_special_token_1|>", num_base_tokens + special_token_count++);
28-
special_tokens.emplace(
29-
"<|reserved_special_token_2|>", num_base_tokens + special_token_count++);
30-
special_tokens.emplace(
31-
"<|reserved_special_token_3|>", num_base_tokens + special_token_count++);
32-
special_tokens.emplace(
33-
"<|start_header_id|>", num_base_tokens + special_token_count++);
34-
special_tokens.emplace(
35-
"<|end_header_id|>", num_base_tokens + special_token_count++);
36-
special_tokens.emplace(
37-
"<|reserved_special_token_4|>", num_base_tokens + special_token_count++);
38-
special_tokens.emplace("<|eot_id|>", num_base_tokens + special_token_count++);
20+
static inline std::unique_ptr<std::vector<std::string>>
21+
_get_default_special_tokens() {
22+
auto special_tokens = std::make_unique<std::vector<std::string>>(
23+
std::vector<std::string>{kBOSToken, kEOSToken});
24+
special_tokens->emplace_back("<|reserved_special_token_0|>");
25+
special_tokens->emplace_back("<|reserved_special_token_1|>");
26+
special_tokens->emplace_back("<|reserved_special_token_2|>");
27+
special_tokens->emplace_back("<|reserved_special_token_3|>");
28+
special_tokens->emplace_back("<|start_header_id|>");
29+
special_tokens->emplace_back("<|end_header_id|>");
30+
special_tokens->emplace_back("<|reserved_special_token_4|>");
31+
special_tokens->emplace_back("<|eot_id|>");
3932

4033
// pad the rest of the special tokens with reserved tokens
4134
ssize_t reserved_special_token_num = 5;
42-
while (special_token_count < kSpecialTokensSize) {
43-
special_tokens.emplace(
35+
while (special_tokens->size() < kSpecialTokensSize) {
36+
special_tokens->emplace_back(
4437
"<|reserved_special_token_" +
45-
std::to_string(reserved_special_token_num++) + "|>",
46-
num_base_tokens + special_token_count++);
38+
std::to_string(reserved_special_token_num++) + "|>");
4739
}
4840
return special_tokens;
4941
}
5042

51-
static inline const Encoder _get_multimodal_special_tokens(
52-
ssize_t num_base_tokens) {
53-
ssize_t special_token_count = 0;
54-
Encoder special_tokens;
55-
special_tokens.emplace(
56-
"<|begin_of_text|>", num_base_tokens + special_token_count++);
57-
special_tokens.emplace(
58-
"<|end_of_text|>", num_base_tokens + special_token_count++);
59-
special_tokens.emplace(
60-
"<|reserved_special_token_0|>", num_base_tokens + special_token_count++);
61-
special_tokens.emplace(
62-
"<|reserved_special_token_1|>", num_base_tokens + special_token_count++);
63-
special_tokens.emplace(
64-
"<|reserved_special_token_2|>", num_base_tokens + special_token_count++);
65-
special_tokens.emplace(
66-
"<|reserved_special_token_3|>", num_base_tokens + special_token_count++);
67-
special_tokens.emplace(
68-
"<|start_header_id|>", num_base_tokens + special_token_count++);
69-
special_tokens.emplace(
70-
"<|end_header_id|>", num_base_tokens + special_token_count++);
71-
special_tokens.emplace("<|eom_id|>", num_base_tokens + special_token_count++);
72-
special_tokens.emplace("<|eot_id|>", num_base_tokens + special_token_count++);
73-
special_tokens.emplace("<|image|>", num_base_tokens + special_token_count++);
43+
static inline std::unique_ptr<std::vector<std::string>>
44+
_get_multimodal_special_tokens() {
45+
auto special_tokens = std::make_unique<std::vector<std::string>>(
46+
std::vector<std::string>{kBOSToken, kEOSToken});
47+
special_tokens->emplace_back("<|reserved_special_token_0|>");
48+
special_tokens->emplace_back("<|reserved_special_token_1|>");
49+
special_tokens->emplace_back("<|reserved_special_token_2|>");
50+
special_tokens->emplace_back("<|reserved_special_token_3|>");
51+
special_tokens->emplace_back("<|start_header_id|>");
52+
special_tokens->emplace_back("<|end_header_id|>");
53+
special_tokens->emplace_back("<|eom_id|>");
54+
special_tokens->emplace_back("<|eot_id|>");
55+
special_tokens->emplace_back("<|image|>");
7456

7557
// pad the rest of the special tokens with reserved tokens except the last
7658
// one
7759
ssize_t reserved_special_token_num = 4;
78-
while (special_token_count < kSpecialTokensSize - 1) {
79-
special_tokens.emplace(
60+
while (special_tokens->size() < kSpecialTokensSize - 1) {
61+
special_tokens->emplace_back(
8062
"<|reserved_special_token_" +
81-
std::to_string(reserved_special_token_num++) + "|>",
82-
num_base_tokens + special_token_count++);
63+
std::to_string(reserved_special_token_num++) + "|>");
8364
}
8465

85-
special_tokens.emplace(
86-
"<|python_tag|>", num_base_tokens + special_token_count++);
66+
special_tokens->emplace_back("<|python_tag|>");
8767

8868
return special_tokens;
8969
}
90-
} // namespace
9170

92-
const Encoder LlamaTiktoken::get_special_tokens(ssize_t num_base_tokens) const {
93-
switch (_version) {
71+
std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
72+
switch (version) {
9473
case MULTIMODAL:
95-
return _get_multimodal_special_tokens(num_base_tokens);
74+
return _get_multimodal_special_tokens();
9675
default:
97-
return _get_default_special_tokens(num_base_tokens);
76+
return _get_default_special_tokens();
9877
}
9978
}
10079

101-
const std::string LlamaTiktoken::get_bos_token() const {
102-
return "<|begin_of_text|>";
103-
}
80+
} // namespace
10481

105-
const std::string LlamaTiktoken::get_eos_token() const {
106-
return "<|end_of_text|>";
82+
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
83+
return std::make_unique<Tiktoken>(
84+
_get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
10785
}
86+
10887
} // namespace executor
10988
} // namespace torch

examples/models/llama2/tokenizer/llama_tiktoken.h

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,7 @@ enum Version {
1818
MULTIMODAL,
1919
};
2020

21-
class LlamaTiktoken : public Tiktoken {
22-
public:
23-
explicit LlamaTiktoken(Version version = Version::DEFAULT)
24-
: Tiktoken(), _version(version) {}
25-
~LlamaTiktoken() override {}
21+
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version = DEFAULT);
2622

27-
protected:
28-
const Encoder get_special_tokens(ssize_t num_base_tokens) const override;
29-
const std::string get_bos_token() const override;
30-
const std::string get_eos_token() const override;
31-
32-
private:
33-
const Version _version;
34-
};
3523
} // namespace executor
3624
} // namespace torch

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class TiktokenExtensionTest : public Test {
2121
public:
2222
void SetUp() override {
2323
torch::executor::runtime_init();
24-
tokenizer_ = std::make_unique<LlamaTiktoken>();
24+
tokenizer_ = get_tiktoken_for_llama();
2525
modelPath_ = std::getenv("RESOURCES_PATH") +
2626
std::string("/test_tiktoken_tokenizer.model");
2727
}
@@ -34,7 +34,7 @@ class MultimodalTiktokenV5ExtensionTest : public Test {
3434
public:
3535
void SetUp() override {
3636
torch::executor::runtime_init();
37-
tokenizer_ = std::make_unique<LlamaTiktoken>(MULTIMODAL);
37+
tokenizer_ = get_tiktoken_for_llama(MULTIMODAL);
3838
modelPath_ = std::getenv("RESOURCES_PATH") +
3939
std::string("/test_tiktoken_tokenizer.model");
4040
}
@@ -144,5 +144,34 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
144144
EXPECT_EQ(out.error(), Error::NotSupported);
145145
}
146146

147+
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidBOSIndex) {
148+
// gtest death test doesn't work on iOS:
149+
// https://github.com/google/googletest/issues/2834
150+
#if !GTEST_OS_IOS
151+
EXPECT_EXIT(
152+
std::make_unique<Tiktoken>(
153+
std::make_unique<std::vector<std::string>>(
154+
std::vector<std::string>{"<|end_of_text|>"}),
155+
1,
156+
0),
157+
::testing::KilledBySignal(SIGABRT),
158+
"");
159+
#endif
160+
}
161+
162+
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
163+
// gtest death test doesn't work on iOS:
164+
// https://github.com/google/googletest/issues/2834
165+
#if !GTEST_OS_IOS
166+
EXPECT_EXIT(
167+
std::make_unique<Tiktoken>(
168+
std::make_unique<std::vector<std::string>>(
169+
std::vector<std::string>{"<|begin_of_text|>"}),
170+
0,
171+
1),
172+
::testing::KilledBySignal(SIGABRT),
173+
"");
174+
#endif
175+
}
147176
} // namespace executor
148177
} // namespace torch

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,38 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
330330
return std::make_pair(tokens, last_piece_token_len);
331331
}
332332

333+
Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
334+
Encoder special_token_encoder;
335+
for (ssize_t i = 0; i < _special_tokens->size(); ++i) {
336+
special_token_encoder.emplace(_special_tokens->at(i), num_base_tokens + i);
337+
}
338+
return special_token_encoder;
339+
}
340+
333341
// -------------------------private method end-------------------------------
334342
// -------------------------public method start-------------------------------
335343

344+
Tiktoken::Tiktoken(
345+
std::unique_ptr<std::vector<std::string>> special_tokens,
346+
size_t bos_token_index,
347+
size_t eos_token_index)
348+
: Tokenizer(),
349+
_special_tokens(std::move(special_tokens)),
350+
_bos_token_index(bos_token_index),
351+
_eos_token_index(eos_token_index) {
352+
ET_CHECK_MSG(
353+
_bos_token_index < _special_tokens->size(),
354+
"invalid bos_token_index %zu",
355+
_bos_token_index);
356+
ET_CHECK_MSG(
357+
_eos_token_index < _special_tokens->size(),
358+
"invalid eos_token_index %zu",
359+
_eos_token_index);
360+
}
361+
336362
Error Tiktoken::load(const std::string& path) {
337363
_encoder = _load_encoder(path);
338-
_special_token_encoder = get_special_tokens(_encoder.size());
364+
_special_token_encoder = _build_special_token_encoder(_encoder.size());
339365

340366
_decoder = _build_decoder(_encoder);
341367
_special_token_decoder = _build_decoder(_special_token_encoder);
@@ -346,8 +372,8 @@ Error Tiktoken::load(const std::string& path) {
346372

347373
// initialize vocab_size, bos_tok, eos_tok
348374
vocab_size_ = _encoder.size() + _special_token_encoder.size();
349-
bos_tok_ = _special_token_encoder.at(get_bos_token());
350-
eos_tok_ = _special_token_encoder.at(get_eos_token());
375+
bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
376+
eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));
351377

352378
initialized_ = true;
353379
return Error::Ok;

examples/models/llama2/tokenizer/tiktoken.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,16 @@ using Re2UPtr = std::unique_ptr<re2::RE2>;
2525

2626
class Tiktoken : public Tokenizer {
2727
public:
28-
explicit Tiktoken() : Tokenizer() {}
29-
virtual ~Tiktoken() {}
28+
/**
29+
* @param[in] special_tokens List of special tokens including bos, eos;
30+
* @param[in] bos_token_index Index of the bos token in special_tokens;
31+
* @param[in] eos_token_index Index of the eos token in special_tokens.
32+
*/
33+
explicit Tiktoken(
34+
std::unique_ptr<std::vector<std::string>> special_tokens,
35+
size_t bos_token_index,
36+
size_t eos_token_index);
37+
~Tiktoken() {}
3038

3139
Error load(const std::string& tokenizer_path) override;
3240

@@ -36,14 +44,6 @@ class Tiktoken : public Tokenizer {
3644
Result<std::string> decode(uint64_t prev_token, uint64_t token)
3745
const override;
3846

39-
protected:
40-
// Provide model specific special tokens.
41-
virtual const Encoder get_special_tokens(ssize_t num_base_tokens) const = 0;
42-
// Provide beginning of sentence token.
43-
virtual const std::string get_bos_token() const = 0;
44-
// Provide end of sentence token.
45-
virtual const std::string get_eos_token() const = 0;
46-
4747
private:
4848
template <typename T>
4949
std::pair<std::optional<std::string>, re2::StringPiece>
@@ -61,6 +61,11 @@ class Tiktoken : public Tokenizer {
6161
const std::string& text,
6262
const T& allowed_special) const;
6363

64+
Encoder _build_special_token_encoder(ssize_t num_base_tokens) const;
65+
66+
std::unique_ptr<std::vector<std::string>> _special_tokens;
67+
size_t _bos_token_index;
68+
size_t _eos_token_index;
6469
// Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
6570
const std::string _pattern =
6671
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";

0 commit comments

Comments (0)