Skip to content

Commit 799fe4d

Browse files
committed
Let models provide their own specific special tokens
Pull Request resolved: #4227 ghstack-source-id: 233578696 Differential Revision: [D59651199](https://our.internmc.facebook.com/intern/diff/D59651199/)
1 parent f9efb05 commit 799fe4d

File tree

9 files changed

+193
-117
lines changed

9 files changed

+193
-117
lines changed

examples/models/llama2/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ target_include_directories(
4343

4444
if(EXECUTORCH_USE_TIKTOKEN)
4545
list(APPEND _llama_runner__srcs
46-
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/tiktoken.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/llama_tiktoken.cpp
4747
)
4848
set(_preprocessor_flag -DET_USE_TIKTOKEN)
4949
endif()

examples/models/llama2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#include <executorch/examples/models/llama2/runner/runner.h>
1313
#if ET_USE_TIKTOKEN
14-
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
14+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
1515
#else /* BPE */
1616
#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
1717
#endif /* ET_USE_TIKTOKEN*/
@@ -81,7 +81,7 @@ Error Runner::load() {
8181

8282
// Load tokenizer
8383
#if ET_USE_TIKTOKEN
84-
tokenizer_ = std::make_unique<Tiktoken>();
84+
tokenizer_ = get_tiktoken_for_llama();
8585
#else
8686
tokenizer_ = std::make_unique<BPETokenizer>();
8787
#endif
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace {
14+
static constexpr int32_t kSpecialTokensSize = 256;
15+
static std::string kBOSToken = "<|begin_of_text|>";
16+
static constexpr size_t kBOSTokenIndex = 0;
17+
static std::string kEOSToken = "<|end_of_text|>";
18+
static constexpr size_t kEOSTokenIndex = 1;
19+
20+
static inline std::unique_ptr<std::vector<std::string>>
21+
_get_default_special_tokens() {
22+
auto special_tokens = std::make_unique<std::vector<std::string>>(
23+
std::vector<std::string>{kBOSToken, kEOSToken});
24+
special_tokens->emplace_back("<|reserved_special_token_0|>");
25+
special_tokens->emplace_back("<|reserved_special_token_1|>");
26+
special_tokens->emplace_back("<|reserved_special_token_2|>");
27+
special_tokens->emplace_back("<|reserved_special_token_3|>");
28+
special_tokens->emplace_back("<|start_header_id|>");
29+
special_tokens->emplace_back("<|end_header_id|>");
30+
special_tokens->emplace_back("<|reserved_special_token_4|>");
31+
special_tokens->emplace_back("<|eot_id|>");
32+
33+
// pad the rest of the special tokens with reserved tokens
34+
ssize_t reserved_special_token_num = 5;
35+
while (special_tokens->size() < kSpecialTokensSize) {
36+
special_tokens->emplace_back(
37+
"<|reserved_special_token_" +
38+
std::to_string(reserved_special_token_num++) + "|>");
39+
}
40+
return special_tokens;
41+
}
42+
43+
static inline std::unique_ptr<std::vector<std::string>>
44+
_get_multimodal_special_tokens() {
45+
auto special_tokens = std::make_unique<std::vector<std::string>>(
46+
std::vector<std::string>{kBOSToken, kEOSToken});
47+
special_tokens->emplace_back("<|reserved_special_token_0|>");
48+
special_tokens->emplace_back("<|reserved_special_token_1|>");
49+
special_tokens->emplace_back("<|reserved_special_token_2|>");
50+
special_tokens->emplace_back("<|reserved_special_token_3|>");
51+
special_tokens->emplace_back("<|start_header_id|>");
52+
special_tokens->emplace_back("<|end_header_id|>");
53+
special_tokens->emplace_back("<|eom_id|>");
54+
special_tokens->emplace_back("<|eot_id|>");
55+
special_tokens->emplace_back("<|image|>");
56+
57+
// pad the rest of the special tokens with reserved tokens except the last
58+
// one
59+
ssize_t reserved_special_token_num = 4;
60+
while (special_tokens->size() < kSpecialTokensSize - 1) {
61+
special_tokens->emplace_back(
62+
"<|reserved_special_token_" +
63+
std::to_string(reserved_special_token_num++) + "|>");
64+
}
65+
66+
special_tokens->emplace_back("<|python_tag|>");
67+
68+
return special_tokens;
69+
}
70+
71+
std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
72+
switch (version) {
73+
case MULTIMODAL:
74+
return _get_multimodal_special_tokens();
75+
default:
76+
return _get_default_special_tokens();
77+
}
78+
}
79+
80+
} // namespace
81+
82+
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
83+
return std::make_unique<Tiktoken>(
84+
_get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
85+
}
86+
87+
} // namespace executor
88+
} // namespace torch
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
enum Version {
17+
DEFAULT,
18+
MULTIMODAL,
19+
};
20+
21+
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version = DEFAULT);
22+
23+
} // namespace executor
24+
} // namespace torch

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def define_common_targets():
2323
name = "tiktoken",
2424
srcs = [
2525
"tiktoken.cpp",
26+
"llama_tiktoken.cpp",
2627
],
2728
exported_headers = [
2829
"tokenizer.h",
2930
"tiktoken.h",
31+
"llama_tiktoken.h",
3032
"base64.h",
3133
],
3234
exported_deps = [

examples/models/llama2/tokenizer/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ set(
2626
test_tiktoken.cpp
2727
test_bpe_tokenizer.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/../tiktoken.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/../llama_tiktoken.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/../bpe_tokenizer.cpp
3031
)
3132

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
9+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
1010
#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
1111
#include <executorch/runtime/platform/runtime.h>
1212
#include <gtest/gtest.h>
@@ -21,7 +21,7 @@ class TiktokenExtensionTest : public Test {
2121
public:
2222
void SetUp() override {
2323
torch::executor::runtime_init();
24-
tokenizer_ = std::make_unique<Tiktoken>();
24+
tokenizer_ = get_tiktoken_for_llama();
2525
modelPath_ = std::getenv("RESOURCES_PATH") +
2626
std::string("/test_tiktoken_tokenizer.model");
2727
}
@@ -34,7 +34,7 @@ class MultimodalTiktokenV5ExtensionTest : public Test {
3434
public:
3535
void SetUp() override {
3636
torch::executor::runtime_init();
37-
tokenizer_ = std::make_unique<Tiktoken>(MULTIMODAL);
37+
tokenizer_ = get_tiktoken_for_llama(MULTIMODAL);
3838
modelPath_ = std::getenv("RESOURCES_PATH") +
3939
std::string("/test_tiktoken_tokenizer.model");
4040
}
@@ -144,5 +144,34 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
144144
EXPECT_EQ(out.error(), Error::NotSupported);
145145
}
146146

147+
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidBOSIndex) {
148+
// gtest death test doesn't work on iOS:
149+
// https://github.com/google/googletest/issues/2834
150+
#if !GTEST_OS_IOS
151+
EXPECT_EXIT(
152+
std::make_unique<Tiktoken>(
153+
std::make_unique<std::vector<std::string>>(
154+
std::vector<std::string>{"<|end_of_text|>"}),
155+
1,
156+
0),
157+
::testing::KilledBySignal(SIGABRT),
158+
"");
159+
#endif
160+
}
161+
162+
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
163+
// gtest death test doesn't work on iOS:
164+
// https://github.com/google/googletest/issues/2834
165+
#if !GTEST_OS_IOS
166+
EXPECT_EXIT(
167+
std::make_unique<Tiktoken>(
168+
std::make_unique<std::vector<std::string>>(
169+
std::vector<std::string>{"<|begin_of_text|>"}),
170+
0,
171+
1),
172+
::testing::KilledBySignal(SIGABRT),
173+
"");
174+
#endif
175+
}
147176
} // namespace executor
148177
} // namespace torch

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,38 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
329329
return std::make_pair(tokens, last_piece_token_len);
330330
}
331331

332+
Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
333+
Encoder special_token_encoder;
334+
for (ssize_t i = 0; i < _special_tokens->size(); ++i) {
335+
special_token_encoder.emplace(_special_tokens->at(i), num_base_tokens + i);
336+
}
337+
return special_token_encoder;
338+
}
339+
332340
// -------------------------private method end-------------------------------
333341
// -------------------------public method start-------------------------------
334342

343+
Tiktoken::Tiktoken(
344+
std::unique_ptr<std::vector<std::string>> special_tokens,
345+
size_t bos_token_index,
346+
size_t eos_token_index)
347+
: Tokenizer(),
348+
_special_tokens(std::move(special_tokens)),
349+
_bos_token_index(bos_token_index),
350+
_eos_token_index(eos_token_index) {
351+
ET_CHECK_MSG(
352+
_bos_token_index < _special_tokens->size(),
353+
"invalid bos_token_index %zu",
354+
_bos_token_index);
355+
ET_CHECK_MSG(
356+
_eos_token_index < _special_tokens->size(),
357+
"invalid eos_token_index %zu",
358+
_eos_token_index);
359+
}
360+
335361
Error Tiktoken::load(const std::string& path) {
336362
_encoder = _load_encoder(path);
337-
_special_token_encoder = _get_special_tokens(_encoder.size());
363+
_special_token_encoder = _build_special_token_encoder(_encoder.size());
338364

339365
_decoder = _build_decoder(_encoder);
340366
_special_token_decoder = _build_decoder(_special_token_encoder);
@@ -345,8 +371,8 @@ Error Tiktoken::load(const std::string& path) {
345371

346372
// initialize vocab_size, bos_tok, eos_tok
347373
vocab_size_ = _encoder.size() + _special_token_encoder.size();
348-
bos_tok_ = _special_token_encoder.at("<|begin_of_text|>");
349-
eos_tok_ = _special_token_encoder.at("<|end_of_text|>");
374+
bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
375+
eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));
350376

351377
initialized_ = true;
352378
return Error::Ok;

0 commit comments

Comments
 (0)