
Commit f1573f2

Move into class
1 parent 421288b commit f1573f2

2 files changed (+54, -13)

extension/llm/tokenizer/hf_tokenizer.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import List, Optional

from tokenizers import Tokenizer


class HuggingFaceTokenizer:
    """
    Tokenizing and encoding/decoding text using the Hugging Face tokenizer.
    """

    def __init__(self, model_path: str, config_path: Optional[str] = None):
        """
        Initializes the Tokenizer with a tokenizer.json from HuggingFace.

        Args:
            model_path (str): The path to the HuggingFace tokenizer.json file.
        """
        assert os.path.isfile(model_path), model_path

        self.model = tokenizer = Tokenizer.from_file(model_path)

        self.n_words: int = tokenizer.get_vocab_size()
        if config_path:
            with open(config_path) as f:
                tokenizer_config = json.load(f)
            # bos_token may be absent or null in the config; fall back to None.
            self.bos_id = (
                self.model.token_to_id(tokenizer_config["bos_token"])
                if tokenizer_config.get("bos_token")
                else None
            )
            self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"])
        else:  # Fallback guess.
            self.bos_id = self.model.token_to_id("<|begin_of_text|>")
            self.eos_id = self.model.token_to_id("<|endoftext|>")

        self.stop_tokens = [
            self.eos_id,
        ]

    def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
        # bos/eos are accepted to match the other tokenizer interfaces
        # but are not applied by this implementation.
        assert type(s) is str
        return self.model.encode(s).ids

    def decode(self, t: List[int]) -> str:
        return self.model.decode(t)

    def decode_token(self, t: int) -> str:
        return self.model.decode([t])
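
For reference, a minimal usage sketch of the new class; the file paths and prompt below are hypothetical placeholders, not part of this commit:

from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer

# Hypothetical paths; any HuggingFace-format tokenizer.json works.
tok = HuggingFaceTokenizer(
    "/path/to/tokenizer.json",
    config_path="/path/to/tokenizer_config.json",  # optional; supplies bos/eos tokens
)

ids = tok.encode("Hello, world!", bos=False, eos=False)
text = tok.decode(ids)
assert tok.eos_id in tok.stop_tokens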

extension/llm/tokenizer/utils.py

Lines changed: 2 additions & 13 deletions
@@ -15,20 +15,9 @@
 def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
     if tokenizer_path.endswith(".json"):
-        from tokenizers import Tokenizer
+        from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer

-        tokenizer = Tokenizer.from_file(tokenizer_path)
-
-        # Keep in line with internal tokenizer apis.
-        tokenizer.n_words = tokenizer.get_vocab_size()
-        tokenizer.decode_token = lambda token: tokenizer.decode([token])
-        original_encode = tokenizer.encode
-        tokenizer.encode = lambda prompt, **kwargs: original_encode(prompt).ids
-
-        if tokenizer_config_path:
-            with open(tokenizer_config_path) as f:
-                tokenizer_config = json.load(f)
-                tokenizer.eos_id = tokenizer.token_to_id(tokenizer_config["eos_token"])
+        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
     else:
         try:
             tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
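
After this change, callers keep using get_tokenizer unchanged; the .json branch now returns the new class. A sketch of the dispatch, with hypothetical paths:

from executorch.extension.llm.tokenizer.utils import get_tokenizer

# A .json path selects the new HuggingFaceTokenizer wrapper...
hf_tok = get_tokenizer("/path/to/tokenizer.json", "/path/to/tokenizer_config.json")
# ...while any other path falls through to the SentencePiece attempt above.
sp_tok = get_tokenizer("/path/to/tokenizer.model")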
