@@ -2,8 +2,11 @@ use fancy_regex::Regex;
 use mlua::prelude::*;
 use rustc_hash::FxHashMap as HashMap;
 use std::collections::HashSet;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
 use std::sync::{Arc, Mutex};
 use std::thread;
+use base64;

 #[cfg(feature = "multithreading")]
 const MAX_NUM_THREADS: usize = 128;
@@ -191,12 +194,12 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {

     let _new = lua.create_function(
         move |_,
-              (encoder, special_tokens_encoder, pattern): (
-            HashMap<LuaString, usize>,
+              (encoder_path, special_tokens_encoder, pattern): (
+            String,
             HashMap<String, usize>,
             String,
         )| {
-            new(&*state, encoder, special_tokens_encoder, pattern);
+            new(&*state, encoder_path, special_tokens_encoder, pattern);
             Ok(())
         },
     )?;
@@ -210,14 +213,21 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {

 fn new(
     state: &State,
-    iencoder: HashMap<LuaString, usize>,
+    encoder_path: String,
     special_tokens_encoder: HashMap<String, usize>,
     pattern: String,
 ) {
-    let encoder: HashMap<Vec<u8>, usize> = iencoder
-        .into_iter()
-        .map(|(k, v)| (k.as_bytes().to_vec(), v))
-        .collect();
+    let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
+    // Read the encoder file: each line is a base64-encoded token and its rank, separated by a space
+    let file = File::open(encoder_path).unwrap();
+    let reader = BufReader::new(file);
+    for line in reader.lines() {
+        let line = line.unwrap();
+        let mut parts = line.split_whitespace();
+        let token = base64::decode(parts.next().unwrap().as_bytes()).unwrap();
+        let rank = parts.next().unwrap().parse().unwrap();
+        encoder.insert(token, rank);
+    }
     let regex = Regex::new(&pattern)
         .map_err(|e| mlua::Error::external(e))
         .unwrap();
@@ -230,11 +240,6 @@ fn new(
         .map_err(|e| mlua::Error::external(e))
         .unwrap()
     };
-    let decoder: HashMap<usize, Vec<u8>> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-    assert!(
-        encoder.len() == decoder.len(),
-        "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-    );
     let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
         .iter()
         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
@@ -245,7 +250,8 @@ fn new(
     *core_bpe_lock = Some(CoreBPENative {
         encoder,
         special_tokens_encoder,
-        decoder,
+        // empty decoder
+        decoder: HashMap::default(),
         special_tokens_decoder,
         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
         special_regex_tls: (0..MAX_NUM_THREADS)
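Every parse step in new() calls unwrap(), so a missing or malformed vocabulary file panics inside the Lua host. Below is a minimal sketch of the same parsing loop with errors propagated instead; load_encoder is a hypothetical helper, not part of this commit, and it assumes the same pre-0.21 base64::decode API the diff uses:

    use rustc_hash::FxHashMap as HashMap;
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    fn load_encoder(path: &str) -> Result<HashMap<Vec<u8>, usize>, Box<dyn std::error::Error>> {
        let reader = BufReader::new(File::open(path)?);
        let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
        for line in reader.lines() {
            let line = line?;
            // Each line holds "<base64 token> <rank>", e.g. "IQ== 0".
            let mut parts = line.split_whitespace();
            let token_b64 = parts.next().ok_or("missing token field")?;
            let rank_str = parts.next().ok_or("missing rank field")?;
            encoder.insert(base64::decode(token_b64)?, rank_str.parse()?);
        }
        Ok(encoder)
    }

A caller such as new() could then surface the failure to Lua via mlua::Error::external rather than aborting the host.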