Commit 225ebe8

Let models provide their own specific special tokens

Differential Revision: D59651199 (https://our.internmc.facebook.com/intern/diff/D59651199/)
ghstack-source-id: 233429303
Pull Request resolved: #4227
1 parent: 074a81e

9 files changed: +162 -128 lines

examples/models/llama2/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ target_include_directories(
 
 if(EXECUTORCH_USE_TIKTOKEN)
   list(APPEND _llama_runner__srcs
-    ${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/tiktoken.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/llama_tiktoken.cpp
   )
   set(_preprocessor_flag -DET_USE_TIKTOKEN)
 endif()

examples/models/llama2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 
 #include <executorch/examples/models/llama2/runner/runner.h>
 #if ET_USE_TIKTOKEN
-#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
+#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
 #else /* BPE */
 #include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
 #endif /* ET_USE_TIKTOKEN*/
@@ -81,7 +81,7 @@ Error Runner::load() {
 
   // Load tokenizer
 #if ET_USE_TIKTOKEN
-  tokenizer_ = std::make_unique<Tiktoken>();
+  tokenizer_ = std::make_unique<LlamaTiktoken>();
 #else
   tokenizer_ = std::make_unique<BPETokenizer>();
 #endif
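
For orientation, here is a minimal sketch of the tokenizer API the runner drives once the tiktoken path is compiled in. The call signatures and the fail-before-load behavior mirror the unit tests further down; the model path is a placeholder:

#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>

#include <memory>

using namespace torch::executor;

void tokenizer_example() {
  // DEFAULT Llama tokenizer; LlamaTiktoken(MULTIMODAL) selects the
  // vision-capable special-token table instead.
  auto tokenizer = std::make_unique<LlamaTiktoken>();

  // encode()/decode() before load() return Error::NotSupported
  // (see EncodeWithoutLoadFails in the tests below).
  Error err = tokenizer->load("/path/to/tokenizer.model");  // placeholder path
  if (err != Error::Ok) {
    return;
  }

  // encode(text, n_bos, n_eos): here, prepend one BOS token, append no EOS.
  Result<std::vector<uint64_t>> ids = tokenizer->encode("hello world", 1, 0);
}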

examples/models/llama2/tokenizer/llama_tiktoken.cpp (new file)

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
+
+namespace torch {
+namespace executor {
+namespace {
+static constexpr int32_t kSpecialTokensSize = 256;
+
+static inline const Encoder _get_default_special_tokens(
+    ssize_t num_base_tokens) {
+  Encoder special_tokens;
+  ssize_t special_token_count = 0;
+  special_tokens.emplace(
+      "<|begin_of_text|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|end_of_text|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_0|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_1|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_2|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_3|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|start_header_id|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|end_header_id|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_4|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace("<|eot_id|>", num_base_tokens + special_token_count++);
+
+  // pad the rest of the special tokens with reserved tokens
+  ssize_t reserved_special_token_num = 5;
+  while (special_token_count < kSpecialTokensSize) {
+    special_tokens.emplace(
+        "<|reserved_special_token_" +
+            std::to_string(reserved_special_token_num++) + "|>",
+        num_base_tokens + special_token_count++);
+  }
+  return special_tokens;
+}
+
+static inline const Encoder _get_multimodal_special_tokens(
+    ssize_t num_base_tokens) {
+  ssize_t special_token_count = 0;
+  Encoder special_tokens;
+  special_tokens.emplace(
+      "<|begin_of_text|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|end_of_text|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_0|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_1|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_2|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|reserved_special_token_3|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|start_header_id|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace(
+      "<|end_header_id|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace("<|eom_id|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace("<|eot_id|>", num_base_tokens + special_token_count++);
+  special_tokens.emplace("<|image|>", num_base_tokens + special_token_count++);
+
+  // pad the rest of the special tokens with reserved tokens except the last
+  // one
+  ssize_t reserved_special_token_num = 4;
+  while (special_token_count < kSpecialTokensSize - 1) {
+    special_tokens.emplace(
+        "<|reserved_special_token_" +
+            std::to_string(reserved_special_token_num++) + "|>",
+        num_base_tokens + special_token_count++);
+  }
+
+  special_tokens.emplace(
+      "<|python_tag|>", num_base_tokens + special_token_count++);
+
+  return special_tokens;
+}
+} // namespace
+
+const Encoder LlamaTiktoken::get_special_tokens(ssize_t num_base_tokens) const {
+  switch (_version) {
+    case MULTIMODAL:
+      return _get_multimodal_special_tokens(num_base_tokens);
+    default:
+      return _get_default_special_tokens(num_base_tokens);
+  }
+}
+} // namespace executor
+} // namespace torch
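
A quick sanity check on the ID scheme above: special tokens are appended directly after the base vocabulary, so each token's ID is num_base_tokens plus its insertion order. Assuming the 128000-token base vocabulary implied by the test expectations further down:

// num_base_tokens = 128000 (inferred from the test expectations below)
// "<|begin_of_text|>" -> 128000 + 0 = 128000   (matches bos_tok())
// "<|end_of_text|>"   -> 128000 + 1 = 128001   (matches eos_tok())
// 128000 base tokens + kSpecialTokensSize (256) = 128256  (matches vocab_size())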

examples/models/llama2/tokenizer/llama_tiktoken.h (new file)

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
+
+namespace torch {
+namespace executor {
+
+enum Version {
+  DEFAULT,
+  MULTIMODAL,
+};
+
+class LlamaTiktoken : public Tiktoken {
+ public:
+  explicit LlamaTiktoken(Version version = Version::DEFAULT)
+      : Tiktoken(), _version(version) {}
+  ~LlamaTiktoken() {}
+
+ protected:
+  const Encoder get_special_tokens(ssize_t num_base_tokens) const override;
+
+ private:
+  const Version _version;
+};
+} // namespace executor
+} // namespace torch
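
This header is the extension point the commit title describes: a model supplies its own special tokens by subclassing Tiktoken and overriding get_special_tokens(). A hypothetical sketch of what another model's tokenizer could look like — MyModelTiktoken and its token strings are invented for illustration and are not part of this commit:

#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>

namespace torch {
namespace executor {

class MyModelTiktoken : public Tiktoken {
 protected:
  // Same override point LlamaTiktoken uses above; IDs are assigned
  // relative to the end of the base vocabulary.
  const Encoder get_special_tokens(ssize_t num_base_tokens) const override {
    Encoder special_tokens;
    ssize_t count = 0;
    special_tokens.emplace("<|my_bos|>", num_base_tokens + count++);  // invented
    special_tokens.emplace("<|my_eos|>", num_base_tokens + count++);  // invented
    return special_tokens;
  }
};
} // namespace executor
} // namespace torch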

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 2 additions & 0 deletions

@@ -23,10 +23,12 @@ def define_common_targets():
         name = "tiktoken",
         srcs = [
             "tiktoken.cpp",
+            "llama_tiktoken.cpp",
         ],
         exported_headers = [
             "tokenizer.h",
             "tiktoken.h",
+            "llama_tiktoken.h",
             "base64.h",
         ],
         exported_deps = [

examples/models/llama2/tokenizer/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ set(
   _tokenizer_test_srcs
   test_tiktoken.cpp
   test_bpe_tokenizer.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/../tiktoken.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../llama_tiktoken.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../bpe_tokenizer.cpp
 )

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 14 additions & 14 deletions

@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
+#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include <executorch/runtime/platform/runtime.h>
 #include <gtest/gtest.h>
@@ -17,11 +17,11 @@ using namespace ::testing;
 namespace torch {
 namespace executor {
 
-class TiktokenExtensionTest : public Test {
+class LlamaTiktokenExtensionTest : public Test {
  public:
   void SetUp() override {
     torch::executor::runtime_init();
-    tokenizer_ = std::make_unique<Tiktoken>();
+    tokenizer_ = std::make_unique<LlamaTiktoken>();
     modelPath_ = std::getenv("RESOURCES_PATH") +
         std::string("/test_tiktoken_tokenizer.model");
   }
@@ -30,11 +30,11 @@ class TiktokenExtensionTest : public Test {
   std::string modelPath_;
 };
 
-class MultimodalTiktokenV5ExtensionTest : public Test {
+class MultimodalLlamaTiktokenV5ExtensionTest : public Test {
  public:
   void SetUp() override {
     torch::executor::runtime_init();
-    tokenizer_ = std::make_unique<Tiktoken>(MULTIMODAL);
+    tokenizer_ = std::make_unique<LlamaTiktoken>(MULTIMODAL);
     modelPath_ = std::getenv("RESOURCES_PATH") +
         std::string("/test_tiktoken_tokenizer.model");
   }
@@ -43,33 +43,33 @@ class MultimodalTiktokenV5ExtensionTest : public Test {
   std::string modelPath_;
 };
 
-TEST_F(TiktokenExtensionTest, EncodeWithoutLoadFails) {
+TEST_F(LlamaTiktokenExtensionTest, EncodeWithoutLoadFails) {
   Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
   EXPECT_EQ(res.error(), Error::NotSupported);
 }
 
-TEST_F(TiktokenExtensionTest, DecodeWithoutLoadFails) {
+TEST_F(LlamaTiktokenExtensionTest, DecodeWithoutLoadFails) {
   auto result = tokenizer_->decode(0, 0);
   EXPECT_EQ(result.error(), Error::NotSupported);
 }
 
-TEST_F(TiktokenExtensionTest, TokenizerVocabSizeIsExpected) {
+TEST_F(LlamaTiktokenExtensionTest, TokenizerVocabSizeIsExpected) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   EXPECT_EQ(tokenizer_->vocab_size(), 128256);
   EXPECT_EQ(tokenizer_->bos_tok(), 128000);
   EXPECT_EQ(tokenizer_->eos_tok(), 128001);
 }
 
-TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerVocabSizeIsExpected) {
+TEST_F(MultimodalLlamaTiktokenV5ExtensionTest, TokenizerVocabSizeIsExpected) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   EXPECT_EQ(tokenizer_->vocab_size(), 128256);
   EXPECT_EQ(tokenizer_->bos_tok(), 128000);
   EXPECT_EQ(tokenizer_->eos_tok(), 128001);
 }
 
-TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
+TEST_F(LlamaTiktokenExtensionTest, TokenizerEncodeCorrectly) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   Result<std::vector<uint64_t>> out = tokenizer_->encode("hello world", 1, 0);
@@ -80,7 +80,7 @@ TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
   EXPECT_EQ(out.get()[2], 1917);
 }
 
-TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerEncodeCorrectly) {
+TEST_F(MultimodalLlamaTiktokenV5ExtensionTest, TokenizerEncodeCorrectly) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   Result<std::vector<uint64_t>> out = tokenizer_->encode(
@@ -101,7 +101,7 @@ TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerEncodeCorrectly) {
   }
 }
 
-TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
+TEST_F(LlamaTiktokenExtensionTest, TokenizerDecodeCorrectly) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   std::vector<std::string> expected = {"<|begin_of_text|>", "hello", " world"};
@@ -113,7 +113,7 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
   }
 }
 
-TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
+TEST_F(MultimodalLlamaTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   std::vector<std::string> expected = {
@@ -135,7 +135,7 @@ TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
   }
 }
 
-TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
+TEST_F(LlamaTiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
   Error res = tokenizer_->load(modelPath_.c_str());
   EXPECT_EQ(res, Error::Ok);
   // The vocab size is 128256; adds 256 just so the token is out of vocab

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
 
 Error Tiktoken::load(const std::string& path) {
   _encoder = _load_encoder(path);
-  _special_token_encoder = _get_special_tokens(_encoder.size());
+  _special_token_encoder = get_special_tokens(_encoder.size());
 
   _decoder = _build_decoder(_encoder);
   _special_token_decoder = _build_decoder(_special_token_encoder);
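
This one-line change is the crux of the refactor: Tiktoken::load() now dispatches through the get_special_tokens() hook that subclasses like LlamaTiktoken override, instead of a fixed private helper. A self-contained toy distillation of that template-method shape — the types here are simplified stand-ins, not the real ExecuTorch classes:

#include <cstdint>
#include <map>
#include <string>

// Stand-in for the real Encoder (string -> token id).
using ToyEncoder = std::map<std::string, uint64_t>;

class ToyTiktoken {
 public:
  virtual ~ToyTiktoken() {}
  void load(uint64_t num_base_tokens) {
    // Virtual dispatch: whichever subclass was instantiated supplies
    // the special-token table, as in Tiktoken::load() above.
    special_ = get_special_tokens(num_base_tokens);
  }

 protected:
  virtual ToyEncoder get_special_tokens(uint64_t num_base_tokens) const = 0;

 private:
  ToyEncoder special_;
};

class ToyLlama : public ToyTiktoken {
 protected:
  ToyEncoder get_special_tokens(uint64_t num_base_tokens) const override {
    return {{"<|begin_of_text|>", num_base_tokens + 0},
            {"<|end_of_text|>", num_base_tokens + 1}};
  }
};

int main() {
  ToyLlama tok;
  tok.load(128000);  // "<|begin_of_text|>" -> 128000, matching the tests above
  return 0;
}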
