Skip to content

Commit d159de2

Browse files
Lunwen He and facebook-github-bot
authored and committed
Let models provide their own specific special tokens (#4227)
Summary: Pull Request resolved: #4227 ghstack-source-id: 233588006 Reviewed By: larryliu0820 Differential Revision: D59651199 fbshipit-source-id: dfcdc03434e8b8aac0ded235048e401cdfd327c0
1 parent 7021d51 commit d159de2

File tree

9 files changed

+195
-117
lines changed

9 files changed

+195
-117
lines changed

examples/models/llama2/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ target_include_directories(
4343

4444
if(EXECUTORCH_USE_TIKTOKEN)
4545
list(APPEND _llama_runner__srcs
46-
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/tiktoken.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/llama_tiktoken.cpp
4747
)
4848
set(_preprocessor_flag -DET_USE_TIKTOKEN)
4949
endif()

examples/models/llama2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#include <executorch/examples/models/llama2/runner/runner.h>
1313
#if ET_USE_TIKTOKEN
14-
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
14+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
1515
#else /* BPE */
1616
#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
1717
#endif /* ET_USE_TIKTOKEN*/
@@ -81,7 +81,7 @@ Error Runner::load() {
8181

8282
// Load tokenizer
8383
#if ET_USE_TIKTOKEN
84-
tokenizer_ = std::make_unique<Tiktoken>();
84+
tokenizer_ = get_tiktoken_for_llama();
8585
#else
8686
tokenizer_ = std::make_unique<BPETokenizer>();
8787
#endif
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>

namespace torch {
namespace executor {
namespace {
// Every llama tiktoken vocabulary carries exactly this many special tokens;
// shorter lists below are padded with "<|reserved_special_token_N|>" entries.
// size_t so comparisons against vector::size() are unsigned/unsigned.
constexpr size_t kSpecialTokensSize = 256;
// Positions of the BOS/EOS tokens inside the special-token lists below.
constexpr size_t kBOSTokenIndex = 0;
constexpr size_t kEOSTokenIndex = 1;

// Special tokens for the default (text-only) llama models.
std::unique_ptr<std::vector<std::string>> _get_default_special_tokens() {
  auto special_tokens =
      std::make_unique<std::vector<std::string>>(std::vector<std::string>{
          "<|begin_of_text|>",
          "<|end_of_text|>",
          "<|reserved_special_token_0|>",
          "<|reserved_special_token_1|>",
          "<|reserved_special_token_2|>",
          "<|reserved_special_token_3|>",
          "<|start_header_id|>",
          "<|end_header_id|>",
          "<|reserved_special_token_4|>",
          "<|eot_id|>"});
  // Final size is known up front; allocate once instead of growing in the loop.
  special_tokens->reserve(kSpecialTokensSize);

  // Pad the rest of the special tokens with reserved tokens. Unsigned counter
  // avoids a signed/unsigned comparison in the loop condition.
  size_t reserved_special_token_num = 5;
  while (special_tokens->size() < kSpecialTokensSize) {
    special_tokens->emplace_back(
        "<|reserved_special_token_" +
        std::to_string(reserved_special_token_num++) + "|>");
  }
  return special_tokens;
}

// Special tokens for the multimodal llama models; differs from the default
// list by "<|eom_id|>" and "<|image|>", and keeps "<|python_tag|>" last.
std::unique_ptr<std::vector<std::string>> _get_multimodal_special_tokens() {
  auto special_tokens =
      std::make_unique<std::vector<std::string>>(std::vector<std::string>{
          "<|begin_of_text|>",
          "<|end_of_text|>",
          "<|reserved_special_token_0|>",
          "<|reserved_special_token_1|>",
          "<|reserved_special_token_2|>",
          "<|reserved_special_token_3|>",
          "<|start_header_id|>",
          "<|end_header_id|>",
          "<|eom_id|>",
          "<|eot_id|>",
          "<|image|>"});
  special_tokens->reserve(kSpecialTokensSize);

  // Pad the rest of the special tokens with reserved tokens except the last
  // slot, which is taken by "<|python_tag|>".
  size_t reserved_special_token_num = 4;
  while (special_tokens->size() < kSpecialTokensSize - 1) {
    special_tokens->emplace_back(
        "<|reserved_special_token_" +
        std::to_string(reserved_special_token_num++) + "|>");
  }

  special_tokens->emplace_back("<|python_tag|>");

  return special_tokens;
}

// Dispatches to the special-token list matching the requested model version.
std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
  switch (version) {
    case MULTIMODAL:
      return _get_multimodal_special_tokens();
    default:
      return _get_default_special_tokens();
  }
}

} // namespace

// Factory for a Tiktoken preloaded with the llama special tokens for
// `version`; BOS/EOS are the first two entries of the chosen list.
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
  return std::make_unique<Tiktoken>(
      _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
}

} // namespace executor
} // namespace torch
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>

namespace torch {
namespace executor {

// Flavor of llama model the tokenizer is being built for; selects which
// special-token list get_tiktoken_for_llama() installs into the Tiktoken.
enum Version {
  DEFAULT,
  MULTIMODAL,
};

// Factory: returns a Tiktoken preloaded with the llama-specific special
// tokens for `version`. The BOS and EOS tokens are the first two entries
// of that list (indices 0 and 1).
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version = DEFAULT);

} // namespace executor
} // namespace torch

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def define_common_targets():
2323
name = "tiktoken",
2424
srcs = [
2525
"tiktoken.cpp",
26+
"llama_tiktoken.cpp",
2627
],
2728
exported_headers = [
2829
"tokenizer.h",
2930
"tiktoken.h",
31+
"llama_tiktoken.h",
3032
"base64.h",
3133
],
3234
exported_deps = [

examples/models/llama2/tokenizer/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ set(
2626
test_tiktoken.cpp
2727
test_bpe_tokenizer.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/../tiktoken.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/../llama_tiktoken.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/../bpe_tokenizer.cpp
3031
)
3132

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
9+
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
1010
#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
1111
#include <executorch/runtime/platform/runtime.h>
1212
#include <gtest/gtest.h>
@@ -21,7 +21,7 @@ class TiktokenExtensionTest : public Test {
2121
public:
2222
void SetUp() override {
2323
torch::executor::runtime_init();
24-
tokenizer_ = std::make_unique<Tiktoken>();
24+
tokenizer_ = get_tiktoken_for_llama();
2525
modelPath_ = std::getenv("RESOURCES_PATH") +
2626
std::string("/test_tiktoken_tokenizer.model");
2727
}
@@ -34,7 +34,7 @@ class MultimodalTiktokenV5ExtensionTest : public Test {
3434
public:
3535
void SetUp() override {
3636
torch::executor::runtime_init();
37-
tokenizer_ = std::make_unique<Tiktoken>(MULTIMODAL);
37+
tokenizer_ = get_tiktoken_for_llama(MULTIMODAL);
3838
modelPath_ = std::getenv("RESOURCES_PATH") +
3939
std::string("/test_tiktoken_tokenizer.model");
4040
}
@@ -144,5 +144,34 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
144144
EXPECT_EQ(out.error(), Error::NotSupported);
145145
}
146146

147+
// bos_token_index (1) is out of range for the one-element special-token
// list, so the Tiktoken constructor must abort via ET_CHECK_MSG.
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidBOSIndex) {
  // gtest death test doesn't work on iOS:
  // https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
  EXPECT_EXIT(
      std::make_unique<Tiktoken>(
          std::make_unique<std::vector<std::string>>(
              std::vector<std::string>{"<|end_of_text|>"}),
          1,
          0),
      ::testing::KilledBySignal(SIGABRT),
      "");
#endif
}
161+
162+
// eos_token_index (1) is out of range for the one-element special-token
// list, so the Tiktoken constructor must abort via ET_CHECK_MSG.
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
  // gtest death test doesn't work on iOS:
  // https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
  EXPECT_EXIT(
      std::make_unique<Tiktoken>(
          std::make_unique<std::vector<std::string>>(
              std::vector<std::string>{"<|begin_of_text|>"}),
          0,
          1),
      ::testing::KilledBySignal(SIGABRT),
      "");
#endif
}
147176
} // namespace executor
148177
} // namespace torch

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,38 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
329329
return std::make_pair(tokens, last_piece_token_len);
330330
}
331331

332+
// Maps each special token to an id in the contiguous range that starts
// right after the base-vocabulary ids [0, num_base_tokens).
Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
  Encoder special_token_encoder;
  // Unsigned index matches vector::size(), avoiding the signed/unsigned
  // comparison the original ssize_t counter produced.
  for (size_t i = 0; i < _special_tokens->size(); ++i) {
    special_token_encoder.emplace(
        _special_tokens->at(i), num_base_tokens + static_cast<ssize_t>(i));
  }
  return special_token_encoder;
}
339+
332340
// -------------------------private method end-------------------------------
333341
// -------------------------public method start-------------------------------
334342

343+
// Constructs a tokenizer whose special-token table is supplied by the
// caller, together with the indices of the BOS and EOS tokens inside that
// table. Aborts via ET_CHECK_MSG if either index is out of range — checked
// eagerly here so a bad index fails at construction time instead of when
// load() later reads the token list.
Tiktoken::Tiktoken(
    std::unique_ptr<std::vector<std::string>> special_tokens,
    size_t bos_token_index,
    size_t eos_token_index)
    : Tokenizer(),
      _special_tokens(std::move(special_tokens)),
      _bos_token_index(bos_token_index),
      _eos_token_index(eos_token_index) {
  ET_CHECK_MSG(
      _bos_token_index < _special_tokens->size(),
      "invalid bos_token_index %zu",
      _bos_token_index);
  ET_CHECK_MSG(
      _eos_token_index < _special_tokens->size(),
      "invalid eos_token_index %zu",
      _eos_token_index);
}
360+
335361
Error Tiktoken::load(const std::string& path) {
336362
_encoder = _load_encoder(path);
337-
_special_token_encoder = _get_special_tokens(_encoder.size());
363+
_special_token_encoder = _build_special_token_encoder(_encoder.size());
338364

339365
_decoder = _build_decoder(_encoder);
340366
_special_token_decoder = _build_decoder(_special_token_encoder);
@@ -345,8 +371,8 @@ Error Tiktoken::load(const std::string& path) {
345371

346372
// initialize vocab_size, bos_tok, eos_tok
347373
vocab_size_ = _encoder.size() + _special_token_encoder.size();
348-
bos_tok_ = _special_token_encoder.at("<|begin_of_text|>");
349-
eos_tok_ = _special_token_encoder.at("<|end_of_text|>");
374+
bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
375+
eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));
350376

351377
initialized_ = true;
352378
return Error::Ok;

0 commit comments

Comments
 (0)