
Commit 4a16f14

Merge pull request #660 from tm-robinson/651-add-tokenization-for-ollama-and-mistral

651 add tokenization for ollama and mistral

2 parents c64ce88 + dc4a76b

7 files changed: +186 -33 lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -40,7 +40,7 @@ class AbstractGraph(ABC):
         ...         return graph
         ...
         >>> my_graph = MyGraph("Example Graph",
-        {"llm": {"model": "openai/gpt-3.5-turbo"}}, "example_source")
+        {"llm": {"model": "gpt-3.5-turbo"}}, "example_source")
         >>> result = my_graph.run()
         """
 
```

scrapegraphai/nodes/parse_node.py

Lines changed: 9 additions & 4 deletions

```diff
@@ -41,6 +41,9 @@ def __init__(
             True if node_config is None else node_config.get("parse_html", True)
         )
 
+        self.llm_model = node_config.get("llm_model")
+        self.chunk_size = node_config.get("chunk_size")
+
     def execute(self, state: dict) -> dict:
         """
         Executes the node's logic to parse the HTML document content and split it into chunks.
@@ -69,19 +72,21 @@ def execute(self, state: dict) -> dict:
             docs_transformed = docs_transformed[0]
 
             chunks = split_text_into_chunks(text=docs_transformed.page_content,
-                                            chunk_size=self.node_config.get("chunk_size", 4096)-250)
+                                            chunk_size=self.chunk_size-250, model=self.llm_model)
         else:
             docs_transformed = docs_transformed[0]
 
-            chunk_size = self.node_config.get("chunk_size", 4096)
+            chunk_size = self.chunk_size
             chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
 
             if isinstance(docs_transformed, Document):
                 chunks = split_text_into_chunks(text=docs_transformed.page_content,
-                                                chunk_size=chunk_size)
+                                                chunk_size=chunk_size,
+                                                model=self.llm_model)
             else:
                 chunks = split_text_into_chunks(text=docs_transformed,
-                                                chunk_size=chunk_size)
+                                                chunk_size=chunk_size,
+                                                model=self.llm_model)
 
         state.update({self.output[0]: chunks})
 
```

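Note that ParseNode no longer falls back to a default chunk size of 4096: both the model and the chunk size now arrive via node_config. A minimal sketch of wiring the node up, assuming ParseNode keeps the usual input/output/node_config constructor; the model choice and node wiring here are illustrative, not taken from the diff:

```python
from langchain_ollama import ChatOllama
from scrapegraphai.nodes.parse_node import ParseNode

llm = ChatOllama(model="llama3")  # illustrative model

parse_node = ParseNode(
    input="doc",
    output=["parsed_doc"],
    node_config={
        "llm_model": llm,    # read in __init__, used to pick the tokenizer
        "chunk_size": 4096,  # required now: execute() no longer defaults it
        "parse_html": True,
    },
)
```
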
scrapegraphai/utils/split_text_into_chunks.py

Lines changed: 40 additions & 22 deletions

```diff
@@ -3,8 +3,9 @@
 """
 from typing import List
 from .tokenizer import num_tokens_calculus  # Import the new tokenizing function
+from langchain_core.language_models.chat_models import BaseChatModel
 
-def split_text_into_chunks(text: str, chunk_size: int) -> List[str]:
+def split_text_into_chunks(text: str, chunk_size: int, model: BaseChatModel, use_semchunk=True) -> List[str]:
     """
     Splits the text into chunks based on the number of tokens.
 
@@ -15,26 +16,43 @@ def split_text_into_chunks(text: str, chunk_size: int) -> List[str]:
     Returns:
         List[str]: A list of text chunks.
     """
-    tokens = num_tokens_calculus(text)
-    if tokens <= chunk_size:
-        return [text]
-
-    chunks = []
-    current_chunk = []
-    current_length = 0
-
-    words = text.split()
-    for word in words:
-        word_tokens = num_tokens_calculus(word)
-        if current_length + word_tokens > chunk_size:
-            chunks.append(' '.join(current_chunk))
-            current_chunk = [word]
-            current_length = word_tokens
-        else:
-            current_chunk.append(word)
-            current_length += word_tokens
 
-    if current_chunk:
-        chunks.append(' '.join(current_chunk))
+    if use_semchunk:
+        from semchunk import chunk
+        def count_tokens(text):
+            return num_tokens_calculus(text, model)
+
+        chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
+
+        chunks = chunk(text=text,
+                       chunk_size=chunk_size,
+                       token_counter=count_tokens,
+                       memoize=False)
+        return chunks
+
+    else:
+
+        tokens = num_tokens_calculus(text, model)
+
+        if tokens <= chunk_size:
+            return [text]
+
+        chunks = []
+        current_chunk = []
+        current_length = 0
+
+        words = text.split()
+        for word in words:
+            word_tokens = num_tokens_calculus(word, model)
+            if current_length + word_tokens > chunk_size:
+                chunks.append(' '.join(current_chunk))
+                current_chunk = [word]
+                current_length = word_tokens
+            else:
+                current_chunk.append(word)
+                current_length += word_tokens
 
-    return chunks
+        if current_chunk:
+            chunks.append(' '.join(current_chunk))
+
+        return chunks
```

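A quick sketch of calling the updated helper directly; the ChatOpenAI instance and input text are illustrative, and the default use_semchunk=True path requires the semchunk package:

```python
from langchain_openai import ChatOpenAI
from scrapegraphai.utils.split_text_into_chunks import split_text_into_chunks

llm = ChatOpenAI(model="gpt-3.5-turbo")  # illustrative model
text = open("page.txt").read()           # any long document text

# Default path: semchunk splits on semantic boundaries, counting tokens with
# the model-aware num_tokens_calculus; chunk_size is first shrunk to
# min(chunk_size - 500, int(chunk_size * 0.9)) as a safety margin.
chunks = split_text_into_chunks(text=text, chunk_size=4096, model=llm)

# Fallback: the original word-by-word splitter, now also model-aware.
chunks = split_text_into_chunks(text=text, chunk_size=4096, model=llm,
                                use_semchunk=False)
```
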
scrapegraphai/utils/tokenizer.py

Lines changed: 25 additions & 6 deletions

```diff
@@ -1,10 +1,29 @@
+"""
+Module for counting tokens and splitting text into chunks
 """
-Module for calculting the token_for_openai
-"""
-import tiktoken
+from typing import List
+from langchain_openai import ChatOpenAI
+from langchain_ollama import ChatOllama
+from langchain_mistralai import ChatMistralAI
+from langchain_core.language_models.chat_models import BaseChatModel
 
-def num_tokens_calculus(string: str) -> int:
+def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
     """Returns the number of tokens in a text string."""
-    encoding = tiktoken.get_encoding("cl100k_base")
-    num_tokens = len(encoding.encode(string))
+
+    if isinstance(llm_model, ChatOpenAI):
+        from .tokenizers.tokenizer_openai import num_tokens_openai
+        num_tokens_fn = num_tokens_openai
+
+    elif isinstance(llm_model, ChatMistralAI):
+        from .tokenizers.tokenizer_mistral import num_tokens_mistral
+        num_tokens_fn = num_tokens_mistral
+
+    elif isinstance(llm_model, ChatOllama):
+        from .tokenizers.tokenizer_ollama import num_tokens_ollama
+        num_tokens_fn = num_tokens_ollama
+
+    else:
+        raise NotImplementedError(f"There is no tokenization implementation for model '{llm_model}'")
+
+    num_tokens = num_tokens_fn(string, llm_model)
     return num_tokens
```
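
With this dispatch in place, callers pass the chat-model instance and the matching backend tokenizer is chosen by isinstance check; for example (model name illustrative):

```python
from langchain_ollama import ChatOllama
from scrapegraphai.utils.tokenizer import num_tokens_calculus

n = num_tokens_calculus("How many tokens is this?", ChatOllama(model="llama3"))

# Any other chat-model class raises NotImplementedError rather than
# silently falling back to the old cl100k_base count.
```
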
scrapegraphai/utils/tokenizers/tokenizer_mistral.py

Lines changed: 46 additions & 0 deletions

```diff
@@ -0,0 +1,46 @@
+"""
+Tokenization utilities for Mistral models
+"""
+from mistral_common.protocol.instruct.messages import UserMessage
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.protocol.instruct.tool_calls import Function, Tool
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from langchain_core.language_models.chat_models import BaseChatModel
+from ..logging import get_logger
+
+
+def num_tokens_mistral(text: str, llm_model:BaseChatModel) -> int:
+    """
+    Estimate the number of tokens in a given text using Mistral's tokenization method,
+    adjusted for different Mistral models.
+
+    Args:
+        text (str): The text to be tokenized and counted.
+        llm_model (BaseChatModel): The specific Mistral model to adjust tokenization.
+
+    Returns:
+        int: The number of tokens in the text.
+    """
+
+    logger = get_logger()
+
+    logger.debug(f"Counting tokens for text of {len(text)} characters")
+    try:
+        model = llm_model.model
+    except AttributeError:
+        raise NotImplementedError(f"The model provider you are using ('{llm_model}') "
+                                  "does not give us a model name so we cannot identify which encoding to use")
+
+    tokenizer = MistralTokenizer.from_model(model)
+
+    tokenized = tokenizer.encode_chat_completion(
+        ChatCompletionRequest(
+            tools=[],
+            messages=[
+                UserMessage(content=text),
+            ],
+            model=model,
+        )
+    )
+    tokens = tokenized.tokens
+    return len(tokens)
```
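
Counting tokens for a Mistral model then looks like this (a sketch; the model name is illustrative, and mistral_common resolves the right tokenizer from it via MistralTokenizer.from_model):

```python
from langchain_mistralai import ChatMistralAI
from scrapegraphai.utils.tokenizers.tokenizer_mistral import num_tokens_mistral

llm = ChatMistralAI(model="open-mistral-7b")  # illustrative model
# Wraps the text in a single UserMessage and counts the encoded request.
print(num_tokens_mistral("How many tokens is this?", llm))
```
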
scrapegraphai/utils/tokenizers/tokenizer_ollama.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -0,0 +1,28 @@
+"""
+Tokenization utilities for Ollama models
+"""
+from langchain_core.language_models.chat_models import BaseChatModel
+from ..logging import get_logger
+
+def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
+    """
+    Estimate the number of tokens in a given text using Ollama's tokenization method,
+    adjusted for different Ollama models.
+
+    Args:
+        text (str): The text to be tokenized and counted.
+        llm_model (BaseChatModel): The specific Ollama model to adjust tokenization.
+
+    Returns:
+        int: The number of tokens in the text.
+    """
+
+    logger = get_logger()
+
+    logger.debug(f"Counting tokens for text of {len(text)} characters")
+
+    # Use langchain token count implementation
+    # NB: https://github.com/ollama/ollama/issues/1716#issuecomment-2074265507
+    tokens = llm_model.get_num_tokens(text)
+    return tokens
+
```
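
Since Ollama exposes no tokenize endpoint (see the linked issue), this delegates to LangChain's get_num_tokens, whose default implementation approximates with a GPT-2 tokenizer, so treat the count as an estimate. A usage sketch (model name illustrative):

```python
from langchain_ollama import ChatOllama
from scrapegraphai.utils.tokenizers.tokenizer_ollama import num_tokens_ollama

llm = ChatOllama(model="llama3")  # illustrative model
print(num_tokens_ollama("How many tokens is this?", llm))  # approximate count
```
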
scrapegraphai/utils/tokenizers/tokenizer_openai.py

Lines changed: 37 additions & 0 deletions

```diff
@@ -0,0 +1,37 @@
+"""
+Tokenization utilities for OpenAI models
+"""
+import tiktoken
+from langchain_core.language_models.chat_models import BaseChatModel
+from ..logging import get_logger
+
+def num_tokens_openai(text: str, llm_model:BaseChatModel) -> int:
+    """
+    Estimate the number of tokens in a given text using OpenAI's tokenization method,
+    adjusted for different OpenAI models.
+
+    Args:
+        text (str): The text to be tokenized and counted.
+        llm_model (BaseChatModel): The specific OpenAI model to adjust tokenization.
+
+    Returns:
+        int: The number of tokens in the text.
+    """
+
+    logger = get_logger()
+
+    logger.debug(f"Counting tokens for text of {len(text)} characters")
+    try:
+        model = llm_model.model_name
+    except AttributeError:
+        raise NotImplementedError(f"The model provider you are using ('{llm_model}') "
+                                  "does not give us a model name so we cannot identify which encoding to use")
+
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        raise NotImplementedError(f"Tiktoken does not support identifying the encoding for "
+                                  "the model '{model}'")
+
+    num_tokens = len(encoding.encode(text))
+    return num_tokens
```
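
For OpenAI models the count is exact, since tiktoken ships the actual encodings; a usage sketch (model name illustrative):

```python
from langchain_openai import ChatOpenAI
from scrapegraphai.utils.tokenizers.tokenizer_openai import num_tokens_openai

llm = ChatOpenAI(model="gpt-3.5-turbo")  # illustrative model
# encoding_for_model maps "gpt-3.5-turbo" to cl100k_base under the hood.
print(num_tokens_openai("How many tokens is this?", llm))
```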
