Skip to content

Commit b78d8b2

Browse files
committed
Let models provide their own specific special tokens
Pull Request resolved: #4227 ghstack-source-id: 233473801 Differential Revision: [D59651199](https://our.internmc.facebook.com/intern/diff/D59651199/)
1 parent e570a22 commit b78d8b2

File tree

9 files changed

+151
-116
lines changed

9 files changed

+151
-116
lines changed

examples/models/llama2/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ target_include_directories(
4343

4444
if(EXECUTORCH_USE_TIKTOKEN)
4545
list(APPEND _llama_runner__srcs
46-
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/tiktoken.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/llama_tiktoken.cpp
4747
)
4848
set(_preprocessor_flag -DET_USE_TIKTOKEN)
4949
endif()

examples/models/llama2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#include <executorch/examples/models/llama2/runner/runner.h>
1313
#if ET_USE_TIKTOKEN
14-
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
14+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
1515
#else /* BPE */
1616
#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
1717
#endif /* ET_USE_TIKTOKEN*/
@@ -81,7 +81,7 @@ Error Runner::load() {
8181

8282
// Load tokenizer
8383
#if ET_USE_TIKTOKEN
84-
tokenizer_ = std::make_unique<Tiktoken>();
84+
tokenizer_ = std::make_unique<LlamaTiktoken>();
8585
#else
8686
tokenizer_ = std::make_unique<BPETokenizer>();
8787
#endif
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace {
14+
static constexpr int32_t kSpecialTokensSize = 256;
15+
16+
static inline const Encoder _get_default_special_tokens(
17+
ssize_t num_base_tokens) {
18+
Encoder special_tokens;
19+
ssize_t special_token_count = 0;
20+
special_tokens.emplace(
21+
"<|begin_of_text|>", num_base_tokens + special_token_count++);
22+
special_tokens.emplace(
23+
"<|end_of_text|>", num_base_tokens + special_token_count++);
24+
special_tokens.emplace(
25+
"<|reserved_special_token_0|>", num_base_tokens + special_token_count++);
26+
special_tokens.emplace(
27+
"<|reserved_special_token_1|>", num_base_tokens + special_token_count++);
28+
special_tokens.emplace(
29+
"<|reserved_special_token_2|>", num_base_tokens + special_token_count++);
30+
special_tokens.emplace(
31+
"<|reserved_special_token_3|>", num_base_tokens + special_token_count++);
32+
special_tokens.emplace(
33+
"<|start_header_id|>", num_base_tokens + special_token_count++);
34+
special_tokens.emplace(
35+
"<|end_header_id|>", num_base_tokens + special_token_count++);
36+
special_tokens.emplace(
37+
"<|reserved_special_token_4|>", num_base_tokens + special_token_count++);
38+
special_tokens.emplace("<|eot_id|>", num_base_tokens + special_token_count++);
39+
40+
// pad the rest of the special tokens with reserved tokens
41+
ssize_t reserved_special_token_num = 5;
42+
while (special_token_count < kSpecialTokensSize) {
43+
special_tokens.emplace(
44+
"<|reserved_special_token_" +
45+
std::to_string(reserved_special_token_num++) + "|>",
46+
num_base_tokens + special_token_count++);
47+
}
48+
return special_tokens;
49+
}
50+
51+
static inline const Encoder _get_multimodal_special_tokens(
52+
ssize_t num_base_tokens) {
53+
ssize_t special_token_count = 0;
54+
Encoder special_tokens;
55+
special_tokens.emplace(
56+
"<|begin_of_text|>", num_base_tokens + special_token_count++);
57+
special_tokens.emplace(
58+
"<|end_of_text|>", num_base_tokens + special_token_count++);
59+
special_tokens.emplace(
60+
"<|reserved_special_token_0|>", num_base_tokens + special_token_count++);
61+
special_tokens.emplace(
62+
"<|reserved_special_token_1|>", num_base_tokens + special_token_count++);
63+
special_tokens.emplace(
64+
"<|reserved_special_token_2|>", num_base_tokens + special_token_count++);
65+
special_tokens.emplace(
66+
"<|reserved_special_token_3|>", num_base_tokens + special_token_count++);
67+
special_tokens.emplace(
68+
"<|start_header_id|>", num_base_tokens + special_token_count++);
69+
special_tokens.emplace(
70+
"<|end_header_id|>", num_base_tokens + special_token_count++);
71+
special_tokens.emplace("<|eom_id|>", num_base_tokens + special_token_count++);
72+
special_tokens.emplace("<|eot_id|>", num_base_tokens + special_token_count++);
73+
special_tokens.emplace("<|image|>", num_base_tokens + special_token_count++);
74+
75+
// pad the rest of the special tokens with reserved tokens except the last
76+
// one
77+
ssize_t reserved_special_token_num = 4;
78+
while (special_token_count < kSpecialTokensSize - 1) {
79+
special_tokens.emplace(
80+
"<|reserved_special_token_" +
81+
std::to_string(reserved_special_token_num++) + "|>",
82+
num_base_tokens + special_token_count++);
83+
}
84+
85+
special_tokens.emplace(
86+
"<|python_tag|>", num_base_tokens + special_token_count++);
87+
88+
return special_tokens;
89+
}
90+
} // namespace
91+
92+
// Picks the special-token table that matches the model flavor chosen at
// construction time. The return type mirrors the base-class override
// signature, so it stays `const Encoder`.
const Encoder LlamaTiktoken::get_special_tokens(ssize_t num_base_tokens) const {
  if (_version == MULTIMODAL) {
    return _get_multimodal_special_tokens(num_base_tokens);
  }
  // Every non-multimodal version uses the default Llama token table.
  return _get_default_special_tokens(num_base_tokens);
}
100+
} // namespace executor
101+
} // namespace torch
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
// Selects which Llama special-token table the tokenizer exposes.
enum Version {
  DEFAULT,
  MULTIMODAL,
};

// Tiktoken subclass that supplies Llama-specific special tokens. The base
// Tiktoken class handles BPE loading/encoding/decoding; this class only
// overrides the special-token hook queried during load().
class LlamaTiktoken : public Tiktoken {
 public:
  // `version` picks the token table (text-only vs. multimodal); it is fixed
  // for the lifetime of the tokenizer.
  explicit LlamaTiktoken(Version version = Version::DEFAULT)
      : Tiktoken(), _version(version) {}
  ~LlamaTiktoken() override {}

 protected:
  // Returns the 256-entry Llama special-token table, with ids numbered
  // starting at `num_base_tokens` (implemented in llama_tiktoken.cpp).
  const Encoder get_special_tokens(ssize_t num_base_tokens) const override;

 private:
  // Model flavor chosen at construction; drives get_special_tokens dispatch.
  const Version _version;
};
33+
} // namespace executor
34+
} // namespace torch

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def define_common_targets():
2323
name = "tiktoken",
2424
srcs = [
2525
"tiktoken.cpp",
26+
"llama_tiktoken.cpp",
2627
],
2728
exported_headers = [
2829
"tokenizer.h",
2930
"tiktoken.h",
31+
"llama_tiktoken.h",
3032
"base64.h",
3133
],
3234
exported_deps = [

examples/models/llama2/tokenizer/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ set(
2626
test_tiktoken.cpp
2727
test_bpe_tokenizer.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/../tiktoken.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/../llama_tiktoken.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/../bpe_tokenizer.cpp
3031
)
3132

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
9+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
1010
#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
1111
#include <executorch/runtime/platform/runtime.h>
1212
#include <gtest/gtest.h>
@@ -21,7 +21,7 @@ class TiktokenExtensionTest : public Test {
2121
public:
2222
void SetUp() override {
2323
torch::executor::runtime_init();
24-
tokenizer_ = std::make_unique<Tiktoken>();
24+
tokenizer_ = std::make_unique<LlamaTiktoken>();
2525
modelPath_ = std::getenv("RESOURCES_PATH") +
2626
std::string("/test_tiktoken_tokenizer.model");
2727
}
@@ -34,7 +34,7 @@ class MultimodalTiktokenV5ExtensionTest : public Test {
3434
public:
3535
void SetUp() override {
3636
torch::executor::runtime_init();
37-
tokenizer_ = std::make_unique<Tiktoken>(MULTIMODAL);
37+
tokenizer_ = std::make_unique<LlamaTiktoken>(MULTIMODAL);
3838
modelPath_ = std::getenv("RESOURCES_PATH") +
3939
std::string("/test_tiktoken_tokenizer.model");
4040
}

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
334334

335335
Error Tiktoken::load(const std::string& path) {
336336
_encoder = _load_encoder(path);
337-
_special_token_encoder = _get_special_tokens(_encoder.size());
337+
_special_token_encoder = get_special_tokens(_encoder.size());
338338

339339
_decoder = _build_decoder(_encoder);
340340
_special_token_decoder = _build_decoder(_special_token_encoder);

examples/models/llama2/tokenizer/tiktoken.h

Lines changed: 6 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,10 @@ using Encoder = std::unordered_map<std::string, uint64_t>;
2424
using Decoder = std::unordered_map<uint64_t, std::string>;
2525
using Re2UPtr = std::unique_ptr<re2::RE2>;
2626

27-
constexpr int32_t kSpecialTokensSize = 256;
28-
29-
enum Version {
30-
DEFAULT,
31-
MULTIMODAL,
32-
};
33-
3427
class Tiktoken : public Tokenizer {
3528
public:
36-
explicit Tiktoken(const Version& version = DEFAULT)
37-
: Tokenizer(), _version(version) {}
38-
~Tiktoken(){};
29+
explicit Tiktoken() : Tokenizer() {}
30+
virtual ~Tiktoken() {}
3931

4032
Error load(const std::string& tokenizer_path) override;
4133

@@ -45,104 +37,11 @@ class Tiktoken : public Tokenizer {
4537
Result<std::string> decode(uint64_t prev_token, uint64_t token)
4638
const override;
4739

48-
private:
49-
static inline const Encoder _get_default_special_tokens(
50-
ssize_t num_base_tokens) {
51-
Encoder special_tokens;
52-
ssize_t special_token_count = 0;
53-
special_tokens.emplace(
54-
"<|begin_of_text|>", num_base_tokens + special_token_count++);
55-
special_tokens.emplace(
56-
"<|end_of_text|>", num_base_tokens + special_token_count++);
57-
special_tokens.emplace(
58-
"<|reserved_special_token_0|>",
59-
num_base_tokens + special_token_count++);
60-
special_tokens.emplace(
61-
"<|reserved_special_token_1|>",
62-
num_base_tokens + special_token_count++);
63-
special_tokens.emplace(
64-
"<|reserved_special_token_2|>",
65-
num_base_tokens + special_token_count++);
66-
special_tokens.emplace(
67-
"<|reserved_special_token_3|>",
68-
num_base_tokens + special_token_count++);
69-
special_tokens.emplace(
70-
"<|start_header_id|>", num_base_tokens + special_token_count++);
71-
special_tokens.emplace(
72-
"<|end_header_id|>", num_base_tokens + special_token_count++);
73-
special_tokens.emplace(
74-
"<|reserved_special_token_4|>",
75-
num_base_tokens + special_token_count++);
76-
special_tokens.emplace(
77-
"<|eot_id|>", num_base_tokens + special_token_count++);
78-
79-
// pad the rest of the special tokens with reserved tokens
80-
ssize_t reserved_special_token_num = 5;
81-
while (special_token_count < kSpecialTokensSize) {
82-
special_tokens.emplace(
83-
"<|reserved_special_token_" +
84-
std::to_string(reserved_special_token_num++) + "|>",
85-
num_base_tokens + special_token_count++);
86-
}
87-
return special_tokens;
88-
}
89-
90-
static inline const Encoder _get_multimodal_special_tokens(
91-
ssize_t num_base_tokens) {
92-
ssize_t special_token_count = 0;
93-
Encoder special_tokens;
94-
special_tokens.emplace(
95-
"<|begin_of_text|>", num_base_tokens + special_token_count++);
96-
special_tokens.emplace(
97-
"<|end_of_text|>", num_base_tokens + special_token_count++);
98-
special_tokens.emplace(
99-
"<|reserved_special_token_0|>",
100-
num_base_tokens + special_token_count++);
101-
special_tokens.emplace(
102-
"<|reserved_special_token_1|>",
103-
num_base_tokens + special_token_count++);
104-
special_tokens.emplace(
105-
"<|reserved_special_token_2|>",
106-
num_base_tokens + special_token_count++);
107-
special_tokens.emplace(
108-
"<|reserved_special_token_3|>",
109-
num_base_tokens + special_token_count++);
110-
special_tokens.emplace(
111-
"<|start_header_id|>", num_base_tokens + special_token_count++);
112-
special_tokens.emplace(
113-
"<|end_header_id|>", num_base_tokens + special_token_count++);
114-
special_tokens.emplace(
115-
"<|eom_id|>", num_base_tokens + special_token_count++);
116-
special_tokens.emplace(
117-
"<|eot_id|>", num_base_tokens + special_token_count++);
118-
special_tokens.emplace(
119-
"<|image|>", num_base_tokens + special_token_count++);
120-
121-
// pad the rest of the special tokens with reserved tokens except the last
122-
// one
123-
ssize_t reserved_special_token_num = 4;
124-
while (special_token_count < kSpecialTokensSize - 1) {
125-
special_tokens.emplace(
126-
"<|reserved_special_token_" +
127-
std::to_string(reserved_special_token_num++) + "|>",
128-
num_base_tokens + special_token_count++);
129-
}
130-
131-
special_tokens.emplace(
132-
"<|python_tag|>", num_base_tokens + special_token_count++);
133-
134-
return special_tokens;
135-
}
136-
137-
inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
138-
switch (_version) {
139-
case MULTIMODAL:
140-
return _get_multimodal_special_tokens(num_base_tokens);
141-
default:
142-
return _get_default_special_tokens(num_base_tokens);
143-
}
144-
}
40+
protected:
41+
// Provide model specific special tokens.
42+
virtual const Encoder get_special_tokens(ssize_t num_base_tokens) const = 0;
14543

44+
private:
14645
template <typename T>
14746
std::pair<std::optional<std::string>, re2::StringPiece>
14847
_split_with_allowed_special_token(
@@ -159,8 +58,6 @@ class Tiktoken : public Tokenizer {
15958
const std::string& text,
16059
const T& allowed_special) const;
16160

162-
const Version _version;
163-
16461
// Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
16562
const std::string _pattern =
16663
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";

0 commit comments

Comments
 (0)