
Commit 1a21957

0.2.0 - fix lua 5.1
1 parent 07fcbd0 commit 1a21957

Showing 6 changed files with 420 additions and 69 deletions.

Cargo.lock

Lines changed: 7 additions & 0 deletions
(Generated file; diff not rendered.)

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ fancy-regex = "0.11.0"
 regex = "1.8.3"
 rustc-hash = "1.1.0"
 bstr = "1.5.0"
+base64 = "0.21.7"
 
 [features]
 lua54 = ["mlua/lua54"]
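The only manifest change is the new base64 dependency, used below in src/lib.rs to decode the token column of the rank file. The commit calls the free function base64::decode, which the 0.21 series keeps but marks deprecated in favor of the Engine API; a minimal sketch of the Engine-based equivalent, assuming the standard alphabet with padding that the free function also uses:

use base64::{engine::general_purpose::STANDARD, Engine as _};

// Decode one base64 token from a rank-file line; returns an error instead of panicking.
fn decode_token(b64: &str) -> Result<Vec<u8>, base64::DecodeError> {
    STANDARD.decode(b64)
}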

src/lib.rs

Lines changed: 20 additions & 14 deletions
@@ -2,8 +2,11 @@ use fancy_regex::Regex;
 use mlua::prelude::*;
 use rustc_hash::FxHashMap as HashMap;
 use std::collections::HashSet;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
 use std::sync::{Arc, Mutex};
 use std::thread;
+use base64;
 
 #[cfg(feature = "multithreading")]
 const MAX_NUM_THREADS: usize = 128;
@@ -191,12 +194,12 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
 
     let _new = lua.create_function(
         move |_,
-              (encoder, special_tokens_encoder, pattern): (
-            HashMap<LuaString, usize>,
+              (encoder_path, special_tokens_encoder, pattern): (
+            String,
             HashMap<String, usize>,
             String,
         )| {
-            new(&*state, encoder, special_tokens_encoder, pattern);
+            new(&*state, encoder_path, special_tokens_encoder, pattern);
             Ok(())
         },
     )?;
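The binding's first argument is now a filesystem path instead of a Lua table of token → rank, which sidesteps marshalling the roughly-100k-entry table through Lua — presumably the Lua 5.1 issue the commit title refers to. A sketch of how a host might drive the reworked binding, assuming the crate is linked as tiktoken_core and that _new is exported as new on the module table; the path, special token, and pattern below are placeholders, not values from the commit:

use mlua::prelude::*;

fn main() -> LuaResult<()> {
    let lua = Lua::new();
    // Hypothetical registration via the module constructor shown above.
    let module = tiktoken_core::tiktoken_core(&lua)?;
    lua.globals().set("tiktoken_core", module)?;
    // The encoder now arrives as a path to a rank file, not as a table.
    lua.load(r#"tiktoken_core.new("cl100k_base.tiktoken", { ["<|endoftext|>"] = 100257 }, [[\S+]])"#)
        .exec()
}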
@@ -210,14 +213,21 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
 
 fn new(
     state: &State,
-    iencoder: HashMap<LuaString, usize>,
+    encoder_path: String,
     special_tokens_encoder: HashMap<String, usize>,
     pattern: String,
 ) {
-    let encoder: HashMap<Vec<u8>, usize> = iencoder
-        .into_iter()
-        .map(|(k, v)| (k.as_bytes().to_vec(), v))
-        .collect();
+    let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
+    // Read the encoder file: each line is a base64-encoded token and its rank, separated by a space
+    let file = File::open(encoder_path).unwrap();
+    let reader = BufReader::new(file);
+    for line in reader.lines() {
+        let line = line.unwrap();
+        let mut parts = line.split_whitespace();
+        let token = base64::decode(parts.next().unwrap().as_bytes()).unwrap();
+        let rank = parts.next().unwrap().parse().unwrap();
+        encoder.insert(token, rank);
+    }
     let regex = Regex::new(&pattern)
         .map_err(|e| mlua::Error::external(e))
         .unwrap();
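The loop above parses the tiktoken rank-file format: each line is a base64-encoded token and its integer rank separated by a space, so a line such as IQ== 0 maps the one-byte token "!" to rank 0. Every step is unwrap()ed, so a missing file or a malformed line panics inside the host process; a sketch of a fallible variant (not from the commit) that surfaces errors instead:

use rustc_hash::FxHashMap as HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader};

fn load_encoder(path: &str) -> io::Result<HashMap<Vec<u8>, usize>> {
    let mut encoder = HashMap::default();
    for line in BufReader::new(File::open(path)?).lines() {
        let line = line?;
        let mut parts = line.split_whitespace();
        // Skip blank or truncated lines instead of panicking.
        let (Some(tok), Some(rank)) = (parts.next(), parts.next()) else {
            continue;
        };
        let token = base64::decode(tok) // deprecated but present in base64 0.21, as used above
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let rank: usize = rank
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        encoder.insert(token, rank);
    }
    Ok(encoder)
}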
@@ -230,11 +240,6 @@ fn new(
         .map_err(|e| mlua::Error::external(e))
         .unwrap()
     };
-    let decoder: HashMap<usize, Vec<u8>> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-    assert!(
-        encoder.len() == decoder.len(),
-        "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-    );
     let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
         .iter()
         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
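Dropping the eager decoder build also drops the assertion that guarded against duplicate ranks, and the struct in the next hunk is now constructed with an empty decoder map. If decoding is needed later, the reverse map can be rebuilt on demand; the removed logic, restated as a standalone helper for reference:

use rustc_hash::FxHashMap as HashMap;

// Invert bytes -> rank into rank -> bytes; this is the map the commit no longer
// stores up front.
fn build_decoder(encoder: &HashMap<Vec<u8>, usize>) -> HashMap<usize, Vec<u8>> {
    encoder.iter().map(|(k, v)| (*v, k.clone())).collect()
}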
@@ -245,7 +250,8 @@ fn new(
     *core_bpe_lock = Some(CoreBPENative {
         encoder,
         special_tokens_encoder,
-        decoder,
+        // empty decoder
+        decoder: HashMap::default(),
         special_tokens_decoder,
         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
         special_regex_tls: (0..MAX_NUM_THREADS)
