Skip to content

Commit 9d9308a

Browse files
larryliu0820facebook-github-bot
authored and committed
Add tiktoken
Summary: C++ implementation of Tiktoken. Added unit tests. Differential Revision: D56053255
1 parent 7adbcad commit 9d9308a

File tree

13 files changed

+805
-7
lines changed

13 files changed

+805
-7
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ option(EXECUTORCH_BUILD_XNNPACK "Build the XNNPACK backend" OFF)
174174

175175
option(EXECUTORCH_BUILD_VULKAN "Build the Vulkan backend" OFF)
176176

177+
option(EXECUTORCH_BUILD_RE2 "Build RE2" OFF)
177178
#
178179
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
179180
#
@@ -518,6 +519,10 @@ if(EXECUTORCH_BUILD_COREML)
518519
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/apple/coreml)
519520
endif()
520521

522+
if(EXECUTORCH_BUILD_RE2)
523+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples/third-party/re2)
524+
endif()
525+
521526
if(EXECUTORCH_BUILD_PYBIND)
522527
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pybind11)
523528

examples/models/llama2/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ set(_common_include_directories ${EXECUTORCH_ROOT}/..)
6868
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
6969
find_package(gflags REQUIRED)
7070

71+
# find RE2 for tokenizer
72+
find_package(re2 REQUIRED)
73+
7174
#
7275
# llama_main: test binary to run llama, with tokenizer and sampler integrated
7376
#
@@ -87,7 +90,7 @@ endif()
8790
# llama_runner library
8891
add_subdirectory(runner)
8992

90-
set(link_libraries)
93+
set(link_libraries gflags re2)
9194
set(_srcs main.cpp)
9295

9396
if(EXECUTORCH_BUILD_OPTIMIZED)
@@ -162,7 +165,7 @@ if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
162165
endif()
163166

164167
target_include_directories(llama_main PUBLIC ${_common_include_directories})
165-
target_link_libraries(llama_main PUBLIC gflags llama_runner ${link_libraries})
168+
target_link_libraries(llama_main PUBLIC llama_runner ${link_libraries})
166169
target_compile_options(llama_main PUBLIC ${_common_compile_options})
167170

168171
if(APPLE)

examples/models/llama2/main.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ DEFINE_int32(
3939
-1,
4040
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
4141

42+
DEFINE_bool(
43+
use_tiktoken,
44+
false,
45+
"Use Tiktoken tokenizer instead of the default BPE tokenizer.");
46+
4247
int32_t main(int32_t argc, char** argv) {
4348
gflags::ParseCommandLineFlags(&argc, &argv, true);
4449

@@ -57,6 +62,8 @@ int32_t main(int32_t argc, char** argv) {
5762

5863
int32_t cpu_threads = FLAGS_cpu_threads;
5964

65+
bool use_tiktoken = FLAGS_use_tiktoken;
66+
6067
#if defined(ET_USE_THREADPOOL)
6168
uint32_t num_performant_cores = cpu_threads == -1
6269
? torch::executorch::cpuinfo::get_num_performant_cores()
@@ -69,7 +76,8 @@ int32_t main(int32_t argc, char** argv) {
6976
}
7077
#endif
7178
// create llama runner
72-
::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
79+
::torch::executor::Runner runner(
80+
model_path, tokenizer_path, temperature, use_tiktoken);
7381

7482
// generate
7583
runner.generate(prompt, seq_len);

examples/models/llama2/runner/runner.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <executorch/examples/models/llama2/runner/runner.h>
1313
#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
14+
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
1415
#include <executorch/extension/evalue_util/print_evalue.h>
1516
#include <executorch/extension/runner_util/managed_tensor.h>
1617

@@ -37,8 +38,10 @@ std::string statsToJsonString(const Runner::Stats& stats);
3738
Runner::Runner(
3839
const std::string& model_path,
3940
const std::string& tokenizer_path,
40-
const float temperature)
41-
: module_(std::make_unique<Module>(
41+
const float temperature,
42+
bool use_tiktoken)
43+
: use_tiktoken_(use_tiktoken),
44+
module_(std::make_unique<Module>(
4245
model_path,
4346
Module::MlockConfig::UseMlockIgnoreErrors)),
4447
tokenizer_path_(tokenizer_path),
@@ -77,7 +80,11 @@ Error Runner::load() {
7780
append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
7881

7982
// Load tokenizer
80-
tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
83+
if (use_tiktoken_) {
84+
tokenizer_ = std::make_unique<Tiktoken>(vocab_size_, bos_id_, eos_id_);
85+
} else {
86+
tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
87+
}
8188
tokenizer_->load(tokenizer_path_);
8289
if (tokenizer_->bos_tok() != bos_id_) {
8390
ET_LOG(

examples/models/llama2/runner/runner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class Runner {
2929
explicit Runner(
3030
const std::string& model_path,
3131
const std::string& tokenizer_path,
32-
const float temperature = 0.8f);
32+
const float temperature = 0.8f,
33+
bool use_tiktoken = false);
3334

3435
struct Stats {
3536
// Scaling factor for timestamps - in this case, we use ms.
@@ -85,6 +86,7 @@ class Runner {
8586
int32_t n_bos_;
8687
int32_t n_eos_;
8788
int32_t max_seq_len_;
89+
bool use_tiktoken_;
8890
bool use_kv_cache_;
8991
bool use_sdpa_with_kv_cache_;
9092
bool append_eos_;
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
// @lint-ignore-every LICENSELINT
9+
/**************************************************************************
10+
Copyright (c) 2023 sewenew
11+
12+
Licensed under the Apache License, Version 2.0 (the "License");
13+
you may not use this file except in compliance with the License.
14+
You may obtain a copy of the License at
15+
16+
http://www.apache.org/licenses/LICENSE-2.0
17+
18+
Unless required by applicable law or agreed to in writing, software
19+
distributed under the License is distributed on an "AS IS" BASIS,
20+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
See the License for the specific language governing permissions and
22+
limitations under the License.
23+
*************************************************************************/
24+
25+
#pragma once
26+
27+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
28+
#include <runtime/platform/assert.h>
29+
#include <cassert>
30+
#include <string>
31+
#include <string_view>
32+
33+
namespace torch {
34+
namespace executor {
35+
namespace base64 {
36+
37+
std::string decode(const std::string_view& input);
38+
39+
namespace detail {
40+
41+
constexpr uint32_t DECODE_TABLE[] = {
42+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
43+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
44+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
45+
255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255,
46+
255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
47+
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
48+
25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33,
49+
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
50+
49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
51+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
52+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
53+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
54+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
55+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
56+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
57+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
58+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
59+
255};
60+
61+
inline void validate(uint32_t v) {
62+
ET_CHECK_MSG(v != 255, "invalid char");
63+
}
64+
65+
inline void decode(const std::string_view& input, std::string& output) {
66+
ET_CHECK_MSG(
67+
input.size() == 4, "input length must be 4, got %zu", input.size());
68+
69+
uint32_t val = 0;
70+
71+
uint8_t c = input[0];
72+
auto v = DECODE_TABLE[c];
73+
validate(v);
74+
val = v;
75+
76+
c = input[1];
77+
v = DECODE_TABLE[c];
78+
validate(v);
79+
val = (val << 6) | v;
80+
81+
c = input[2];
82+
v = DECODE_TABLE[c];
83+
validate(v);
84+
val = (val << 6) | v;
85+
86+
c = input[3];
87+
v = DECODE_TABLE[c];
88+
validate(v);
89+
val = (val << 6) | v;
90+
91+
output.push_back(static_cast<char>((val >> 16) & 0xFF));
92+
output.push_back(static_cast<char>((val >> 8) & 0xFF));
93+
output.push_back(static_cast<char>(val & 0xFF));
94+
}
95+
96+
inline void decode_1_padding(
97+
const std::string_view& input,
98+
std::string& output) {
99+
ET_CHECK_MSG(
100+
input.size() == 3, "input length must be 3, got %zu", input.size());
101+
102+
uint32_t val = 0;
103+
104+
uint8_t c = input[0];
105+
auto v = DECODE_TABLE[c];
106+
validate(v);
107+
val = v;
108+
109+
c = input[1];
110+
v = DECODE_TABLE[c];
111+
validate(v);
112+
val = (val << 6) | v;
113+
114+
c = input[2];
115+
v = DECODE_TABLE[c];
116+
validate(v);
117+
val = (val << 6) | v;
118+
119+
output.push_back(static_cast<char>((val >> 10) & 0xFF));
120+
output.push_back(static_cast<char>((val >> 2) & 0xFF));
121+
}
122+
123+
inline void decode_2_padding(
124+
const std::string_view& input,
125+
std::string& output) {
126+
assert(input.size() == 2);
127+
128+
uint32_t val = 0;
129+
130+
uint8_t c = input[0];
131+
auto v = DECODE_TABLE[c];
132+
validate(v);
133+
val = v;
134+
135+
c = input[1];
136+
v = DECODE_TABLE[c];
137+
validate(v);
138+
val = (val << 6) | v;
139+
140+
output.push_back(static_cast<char>((val >> 4) & 0xFF));
141+
}
142+
143+
} // namespace detail
144+
145+
inline std::string decode(const std::string_view& input) {
146+
ET_CHECK_MSG(!input.empty(), "empty input");
147+
148+
// Faster than `input.size() % 4`.
149+
ET_CHECK_MSG(
150+
(input.size() & 3) == 0 && input.size() >= 4,
151+
"input length must be larger than 4 and is multiple of 4, got %zu",
152+
input.size());
153+
154+
std::string output;
155+
output.reserve(input.size() / 4 * 3);
156+
auto idx = 0U;
157+
for (; idx < input.size() - 4; idx += 4) {
158+
detail::decode(input.substr(idx, 4), output);
159+
}
160+
161+
// Last 4 bytes. Might contain paddings.
162+
if (input[idx + 3] == '=') {
163+
if (input[idx + 2] == '=') {
164+
// Tow paddings.
165+
detail::decode_2_padding(input.substr(idx, 2), output);
166+
} else {
167+
// One padding.
168+
detail::decode_1_padding(input.substr(idx, 3), output);
169+
}
170+
} else {
171+
// No padding.
172+
detail::decode(input.substr(idx, 4), output);
173+
}
174+
175+
return output;
176+
}
177+
178+
} // namespace base64
179+
180+
} // namespace executor
181+
} // namespace torch

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ def define_common_targets():
55
name = "tokenizer",
66
srcs = [
77
"bpe_tokenizer.cpp",
8+
"tiktoken.cpp",
89
],
910
exported_headers = [
1011
"tokenizer.h",
1112
"bpe_tokenizer.h",
13+
"tiktoken.h",
14+
"base64.h",
1215
],
1316
exported_deps = [
1417
"//executorch/runtime/core/exec_aten:lib",
@@ -17,6 +20,9 @@ def define_common_targets():
1720
visibility = [
1821
"@EXECUTORCH_CLIENTS",
1922
],
23+
exported_external_deps = [
24+
"re2",
25+
],
2026
)
2127

2228
runtime.python_library(

examples/models/llama2/tokenizer/test/targets.bzl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,36 @@ def define_common_targets():
2020
},
2121
)
2222

23+
runtime.cxx_test(
24+
name = "test_tiktoken",
25+
srcs = [
26+
"test_tiktoken.cpp",
27+
],
28+
deps = [
29+
"//executorch/examples/models/llama2/tokenizer:tokenizer",
30+
],
31+
env = {
32+
"RESOURCES_PATH": "$(location :resources_fb_only)/resources",
33+
},
34+
external_deps = [
35+
"re2",
36+
],
37+
)
38+
2339
runtime.filegroup(
2440
name = "resources",
2541
srcs = native.glob([
2642
"resources/**",
2743
]),
2844
)
2945

46+
runtime.filegroup(
47+
name = "resources_fb_only",
48+
srcs = native.glob([
49+
"resources/fb/**",
50+
]),
51+
)
52+
3053
runtime.python_test(
3154
name = "test_tokenizer_py",
3255
srcs = [

0 commit comments

Comments
 (0)