Skip to content

Commit 3f53253

Browse files
larryliu0820 authored and facebook-github-bot committed
Add tiktoken (#3015)
Summary: Pull Request resolved: #3015 C++ implementation of Tiktoken. Added unit tests. Reviewed By: lucylq Differential Revision: D56053255
1 parent 74576e8 commit 3f53253

File tree

14 files changed

+817
-8
lines changed

14 files changed

+817
-8
lines changed

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,9 @@
6262
[submodule "examples/third-party/LLaVA"]
6363
path = examples/third-party/LLaVA
6464
url = https://github.com/haotian-liu/LLaVA.git
65+
[submodule "examples/models/llama2/third-party/re2"]
66+
path = examples/models/llama2/third-party/re2
67+
url = https://github.com/google/re2.git
68+
[submodule "examples/models/llama2/third-party/abseil-cpp"]
69+
path = examples/models/llama2/third-party/abseil-cpp
70+
url = https://github.com/abseil/abseil-cpp.git

examples/models/llama2/CMakeLists.txt

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ project(llama_runner)
2121
# Duplicating options as root CMakeLists.txt
2222
option(EXECUTORCH_BUILD_OPTIMIZED "Build the optimized kernels" OFF)
2323

24+
option(EXECUTORCH_BUILD_RE2 "Build RE2" OFF)
25+
2426
include(CMakeDependentOption)
2527
#
2628
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
@@ -86,8 +88,19 @@ endif()
8688

8789
# llama_runner library
8890
add_subdirectory(runner)
89-
90-
set(link_libraries)
91+
if(EXECUTORCH_BUILD_RE2)
92+
# find RE2 for tokenizer
93+
set(ABSL_ENABLE_INSTALL ON)
94+
set(_pic_flag
95+
${CMAKE_POSITION_INDEPENDENT_CODE})
96+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
97+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp)
98+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2)
99+
set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
100+
target_link_libraries(llama_runner PUBLIC re2::re2)
101+
endif()
102+
103+
set(link_libraries gflags)
91104
set(_srcs main.cpp)
92105

93106
if(EXECUTORCH_BUILD_OPTIMIZED)
@@ -162,7 +175,7 @@ if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
162175
endif()
163176

164177
target_include_directories(llama_main PUBLIC ${_common_include_directories})
165-
target_link_libraries(llama_main PUBLIC gflags llama_runner ${link_libraries})
178+
target_link_libraries(llama_main PUBLIC llama_runner ${link_libraries})
166179
target_compile_options(llama_main PUBLIC ${_common_compile_options})
167180

168181
if(APPLE)

examples/models/llama2/main.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ DEFINE_int32(
3939
-1,
4040
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
4141

42+
DEFINE_bool(
43+
use_tiktoken,
44+
false,
45+
"Use Tiktoken tokenizer instead of the default BPE tokenizer.");
46+
4247
int32_t main(int32_t argc, char** argv) {
4348
gflags::ParseCommandLineFlags(&argc, &argv, true);
4449

@@ -57,6 +62,8 @@ int32_t main(int32_t argc, char** argv) {
5762

5863
int32_t cpu_threads = FLAGS_cpu_threads;
5964

65+
bool use_tiktoken = FLAGS_use_tiktoken;
66+
6067
#if defined(ET_USE_THREADPOOL)
6168
uint32_t num_performant_cores = cpu_threads == -1
6269
? torch::executorch::cpuinfo::get_num_performant_cores()
@@ -69,7 +76,8 @@ int32_t main(int32_t argc, char** argv) {
6976
}
7077
#endif
7178
// create llama runner
72-
::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
79+
::torch::executor::Runner runner(
80+
model_path, tokenizer_path, temperature, use_tiktoken);
7381

7482
// generate
7583
runner.generate(prompt, seq_len);

examples/models/llama2/runner/runner.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <executorch/examples/models/llama2/runner/runner.h>
1313
#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
14+
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
1415
#include <executorch/extension/evalue_util/print_evalue.h>
1516
#include <executorch/extension/runner_util/managed_tensor.h>
1617

@@ -37,8 +38,10 @@ std::string statsToJsonString(const Runner::Stats& stats);
3738
Runner::Runner(
3839
const std::string& model_path,
3940
const std::string& tokenizer_path,
40-
const float temperature)
41-
: module_(std::make_unique<Module>(
41+
const float temperature,
42+
bool use_tiktoken)
43+
: use_tiktoken_(use_tiktoken),
44+
module_(std::make_unique<Module>(
4245
model_path,
4346
Module::MlockConfig::UseMlockIgnoreErrors)),
4447
tokenizer_path_(tokenizer_path),
@@ -77,7 +80,11 @@ Error Runner::load() {
7780
append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
7881

7982
// Load tokenizer
80-
tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
83+
if (use_tiktoken_) {
84+
tokenizer_ = std::make_unique<Tiktoken>(vocab_size_, bos_id_, eos_id_);
85+
} else {
86+
tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
87+
}
8188
tokenizer_->load(tokenizer_path_);
8289
if (tokenizer_->bos_tok() != bos_id_) {
8390
ET_LOG(

examples/models/llama2/runner/runner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class Runner {
2929
explicit Runner(
3030
const std::string& model_path,
3131
const std::string& tokenizer_path,
32-
const float temperature = 0.8f);
32+
const float temperature = 0.8f,
33+
bool use_tiktoken = false);
3334

3435
struct Stats {
3536
// Scaling factor for timestamps - in this case, we use ms.
@@ -85,6 +86,7 @@ class Runner {
8586
int32_t n_bos_;
8687
int32_t n_eos_;
8788
int32_t max_seq_len_;
89+
bool use_tiktoken_;
8890
bool use_kv_cache_;
8991
bool use_sdpa_with_kv_cache_;
9092
bool append_eos_;
Submodule abseil-cpp added at 8541930
Submodule re2 added at ac82d4f
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
// @lint-ignore-every LICENSELINT
9+
/**************************************************************************
10+
Copyright (c) 2023 sewenew
11+
12+
Licensed under the Apache License, Version 2.0 (the "License");
13+
you may not use this file except in compliance with the License.
14+
You may obtain a copy of the License at
15+
16+
http://www.apache.org/licenses/LICENSE-2.0
17+
18+
Unless required by applicable law or agreed to in writing, software
19+
distributed under the License is distributed on an "AS IS" BASIS,
20+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
See the License for the specific language governing permissions and
22+
limitations under the License.
23+
*************************************************************************/
24+
25+
#pragma once
26+
27+
#include <executorch/runtime/platform/assert.h>
28+
#include <cassert>
29+
#include <string>
30+
#include <string_view>
31+
32+
namespace torch {
33+
namespace executor {
34+
namespace base64 {
35+
36+
std::string decode(const std::string_view& input);
37+
38+
namespace detail {
39+
40+
// Maps an ASCII byte to its 6-bit base64 value: 'A'-'Z' -> 0-25,
// 'a'-'z' -> 26-51, '0'-'9' -> 52-61, '+' -> 62, '/' -> 63.
// The sentinel 255 marks every byte that is not part of the standard
// base64 alphabet; lookups are validated against it in validate().
constexpr uint32_t DECODE_TABLE[] = {
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62,  255,
    255, 255, 63,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  255, 255,
    255, 255, 255, 255, 255, 0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
    10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,
    25,  255, 255, 255, 255, 255, 255, 26,  27,  28,  29,  30,  31,  32,  33,
    34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
    49,  50,  51,  255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255};
59+
60+
// Aborts (via ET_CHECK_MSG) when a DECODE_TABLE lookup produced the
// 255 "not a base64 character" sentinel.
inline void validate(uint32_t v) {
  ET_CHECK_MSG(v != 255, "invalid char");
}
63+
64+
// Decodes one full base64 quantum (4 characters -> 3 bytes) and
// appends the decoded bytes to `output`. Aborts if `input` is not
// exactly 4 characters or contains a non-alphabet character.
inline void decode(const std::string_view& input, std::string& output) {
  ET_CHECK_MSG(
      input.size() == 4, "input length must be 4, got %zu", input.size());

  // Accumulate four 6-bit groups into a 24-bit value.
  uint32_t accum = 0;
  for (size_t i = 0; i < 4; ++i) {
    const uint8_t ch = input[i];
    const uint32_t bits = DECODE_TABLE[ch];
    validate(bits);
    accum = (accum << 6) | bits;
  }

  // Emit the three decoded bytes, most significant first.
  output.push_back(static_cast<char>((accum >> 16) & 0xFF));
  output.push_back(static_cast<char>((accum >> 8) & 0xFF));
  output.push_back(static_cast<char>(accum & 0xFF));
}
94+
95+
inline void decode_1_padding(
96+
const std::string_view& input,
97+
std::string& output) {
98+
ET_CHECK_MSG(
99+
input.size() == 3, "input length must be 3, got %zu", input.size());
100+
101+
uint32_t val = 0;
102+
103+
uint8_t c = input[0];
104+
auto v = DECODE_TABLE[c];
105+
validate(v);
106+
val = v;
107+
108+
c = input[1];
109+
v = DECODE_TABLE[c];
110+
validate(v);
111+
val = (val << 6) | v;
112+
113+
c = input[2];
114+
v = DECODE_TABLE[c];
115+
validate(v);
116+
val = (val << 6) | v;
117+
118+
output.push_back(static_cast<char>((val >> 10) & 0xFF));
119+
output.push_back(static_cast<char>((val >> 2) & 0xFF));
120+
}
121+
122+
inline void decode_2_padding(
123+
const std::string_view& input,
124+
std::string& output) {
125+
assert(input.size() == 2);
126+
127+
uint32_t val = 0;
128+
129+
uint8_t c = input[0];
130+
auto v = DECODE_TABLE[c];
131+
validate(v);
132+
val = v;
133+
134+
c = input[1];
135+
v = DECODE_TABLE[c];
136+
validate(v);
137+
val = (val << 6) | v;
138+
139+
output.push_back(static_cast<char>((val >> 4) & 0xFF));
140+
}
141+
142+
} // namespace detail
143+
144+
// Decodes a base64-encoded string (standard alphabet, optional '='
// padding in the final quantum) and returns the decoded bytes.
// Aborts via ET_CHECK_MSG on empty input, on a length that is not a
// positive multiple of 4, or on any non-alphabet character.
inline std::string decode(const std::string_view& input) {
  ET_CHECK_MSG(!input.empty(), "empty input");

  // Faster than `input.size() % 4`.
  ET_CHECK_MSG(
      (input.size() & 3) == 0 && input.size() >= 4,
      "input length must be at least 4 and a multiple of 4, got %zu",
      input.size());

  std::string output;
  output.reserve(input.size() / 4 * 3);
  auto idx = 0U;
  // Decode every quantum except the last; only the last may be padded.
  for (; idx < input.size() - 4; idx += 4) {
    detail::decode(input.substr(idx, 4), output);
  }

  // Last 4 bytes. Might contain paddings.
  if (input[idx + 3] == '=') {
    if (input[idx + 2] == '=') {
      // Two paddings.
      detail::decode_2_padding(input.substr(idx, 2), output);
    } else {
      // One padding.
      detail::decode_1_padding(input.substr(idx, 3), output);
    }
  } else {
    // No padding.
    detail::decode(input.substr(idx, 4), output);
  }

  return output;
}
176+
177+
} // namespace base64
178+
179+
} // namespace executor
180+
} // namespace torch

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ def define_common_targets():
55
name = "tokenizer",
66
srcs = [
77
"bpe_tokenizer.cpp",
8+
"tiktoken.cpp",
89
],
910
exported_headers = [
1011
"tokenizer.h",
1112
"bpe_tokenizer.h",
13+
"tiktoken.h",
14+
"base64.h",
1215
],
1316
exported_deps = [
1417
"//executorch/runtime/core/exec_aten:lib",
@@ -17,6 +20,9 @@ def define_common_targets():
1720
visibility = [
1821
"@EXECUTORCH_CLIENTS",
1922
],
23+
exported_external_deps = [
24+
"re2",
25+
],
2026
)
2127

2228
runtime.python_library(

examples/models/llama2/tokenizer/test/targets.bzl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,36 @@ def define_common_targets():
2020
},
2121
)
2222

23+
runtime.cxx_test(
24+
name = "test_tiktoken",
25+
srcs = [
26+
"test_tiktoken.cpp",
27+
],
28+
deps = [
29+
"//executorch/examples/models/llama2/tokenizer:tokenizer",
30+
],
31+
env = {
32+
"RESOURCES_PATH": "$(location :resources_fb_only)/resources",
33+
},
34+
external_deps = [
35+
"re2",
36+
],
37+
)
38+
2339
runtime.filegroup(
2440
name = "resources",
2541
srcs = native.glob([
2642
"resources/**",
2743
]),
2844
)
2945

46+
runtime.filegroup(
47+
name = "resources_fb_only",
48+
srcs = native.glob([
49+
"resources/fb/**",
50+
]),
51+
)
52+
3053
runtime.python_test(
3154
name = "test_tokenizer_py",
3255
srcs = [

0 commit comments

Comments
 (0)