
Commit 8a37c6b

feat: added Bedrock and Mistral to exec info
1 parent 4f8b55d commit 8a37c6b

6 files changed: +311 −25 lines

scrapegraphai/graphs/base_graph.py

Lines changed: 10 additions & 7 deletions
@@ -5,7 +5,7 @@
 import warnings
 from typing import Tuple
 from ..telemetry import log_graph_execution
-from ..utils import CustomOpenAiCallbackManager
+from ..utils import CustomLLMCallbackManager

 class BaseGraph:
     """
@@ -52,7 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
         self.entry_point = entry_point.node_name
         self.graph_name = graph_name
         self.initial_state = {}
-        self.callback_manager = CustomOpenAiCallbackManager()
+        self.callback_manager = CustomLLMCallbackManager()

         if nodes[0].node_name != entry_point.node_name:
             # raise a warning if the entry point is not the first node in the list
@@ -108,6 +108,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
         error_node = None
         source_type = None
         llm_model = None
+        llm_model_name = None
         embedder_model = None
         source = []
         prompt = None
@@ -135,9 +136,11 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             if hasattr(current_node, "llm_model") and llm_model is None:
                 llm_model = current_node.llm_model
                 if hasattr(llm_model, "model_name"):
-                    llm_model = llm_model.model_name
+                    llm_model_name = llm_model.model_name
                 elif hasattr(llm_model, "model"):
-                    llm_model = llm_model.model
+                    llm_model_name = llm_model.model
+                elif hasattr(llm_model, "model_id"):
+                    llm_model_name = llm_model.model_id

             if hasattr(current_node, "embedder_model") and embedder_model is None:
                 embedder_model = current_node.embedder_model
@@ -155,7 +158,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             except Exception as e:
                 schema = None

-            with self.callback_manager.exclusive_get_openai_callback() as cb:
+            with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
                 try:
                     result = current_node.execute(state)
                 except Exception as e:
@@ -166,7 +169,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
                         source=source,
                         prompt=prompt,
                         schema=schema,
-                        llm_model=llm_model,
+                        llm_model=llm_model_name,
                         embedder_model=embedder_model,
                         source_type=source_type,
                         execution_time=graph_execution_time,
@@ -222,7 +225,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             source=source,
             prompt=prompt,
             schema=schema,
-            llm_model=llm_model,
+            llm_model=llm_model_name,
             embedder_model=embedder_model,
             source_type=source_type,
             content=content,
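
The fallback chain above is what lets telemetry report a usable model identifier for every provider: OpenAI-style models expose model_name, most other LangChain chat models expose model, and Bedrock models expose model_id. A minimal sketch of the same resolution logic in isolation (FakeBedrockModel is a hypothetical stand-in, not part of the codebase):

class FakeBedrockModel:
    # Hypothetical stand-in: Bedrock chat models identify themselves via model_id.
    model_id = "mistral.mistral-large-2407-v1:0"

def resolve_model_name(llm_model):
    # Mirrors the attribute-fallback chain in BaseGraph._execute_standard.
    for attr in ("model_name", "model", "model_id"):
        if hasattr(llm_model, attr):
            return getattr(llm_model, attr)
    return None

print(resolve_model_name(FakeBedrockModel()))  # mistral.mistral-large-2407-v1:0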

scrapegraphai/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@
 from .screenshot_scraping.text_detection import detect_text
 from .tokenizer import num_tokens_calculus
 from .split_text_into_chunks import split_text_into_chunks
-from .custom_openai_callback import CustomOpenAiCallbackManager
+from .llm_callback_manager import CustomLLMCallbackManager
scrapegraphai/utils/custom_callback.py

Lines changed: 157 additions & 0 deletions

"""
Custom callback for LLM token usage statistics.

This module has been taken and modified from the OpenAI callback manager in langchain-community.
https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py
"""
from contextlib import contextmanager
import threading
from typing import Any, Dict, List, Optional
from contextvars import ContextVar

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.tracers.context import register_configure_hook

from .model_costs import MODEL_COST_PER_1K_TOKENS_INPUT, MODEL_COST_PER_1K_TOKENS_OUTPUT


def get_token_cost_for_model(
    model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
    """
    Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.
        is_completion: Whether the model is used for completion or not.
            Defaults to False.

    Returns:
        Cost in USD.
    """
    if model_name not in MODEL_COST_PER_1K_TOKENS_INPUT:
        return 0.0
    if is_completion:
        return MODEL_COST_PER_1K_TOKENS_OUTPUT[model_name] * (num_tokens / 1000)

    return MODEL_COST_PER_1K_TOKENS_INPUT[model_name] * (num_tokens / 1000)


class CustomCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks LLMs info."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self, llm_model_name: str) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self.model_name = llm_model_name if llm_model_name else "unknown"

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Print out the token."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        # Check for usage_metadata (langchain-core >= 0.2.2)
        try:
            generation = response.generations[0][0]
        except IndexError:
            generation = None
        if isinstance(generation, ChatGeneration):
            try:
                message = generation.message
                if isinstance(message, AIMessage):
                    usage_metadata = message.usage_metadata
                else:
                    usage_metadata = None
            except AttributeError:
                usage_metadata = None
        else:
            usage_metadata = None
        if usage_metadata:
            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
            completion_tokens = usage_metadata["output_tokens"]
            prompt_tokens = usage_metadata["input_tokens"]
        else:
            if response.llm_output is None:
                return None

            if "token_usage" not in response.llm_output:
                with self._lock:
                    self.successful_requests += 1
                return None

            # compute tokens and cost for this request
            token_usage = response.llm_output["token_usage"]
            completion_tokens = token_usage.get("completion_tokens", 0)
            prompt_tokens = token_usage.get("prompt_tokens", 0)
        if self.model_name in MODEL_COST_PER_1K_TOKENS_INPUT:
            completion_cost = get_token_cost_for_model(
                self.model_name, completion_tokens, is_completion=True
            )
            prompt_cost = get_token_cost_for_model(self.model_name, prompt_tokens)
        else:
            completion_cost = 0
            prompt_cost = 0

        # update shared state behind lock
        with self._lock:
            self.total_cost += prompt_cost + completion_cost
            self.total_tokens += token_usage.get("total_tokens", 0)
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.successful_requests += 1

    def __copy__(self) -> "CustomCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "CustomCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self


custom_callback: ContextVar[Optional[CustomCallbackHandler]] = ContextVar(
    "custom_callback", default=None
)
register_configure_hook(custom_callback, True)


@contextmanager
def get_custom_callback(llm_model_name: str):
    """
    Function to get custom callback for LLM token usage statistics.
    """
    cb = CustomCallbackHandler(llm_model_name)
    custom_callback.set(cb)
    yield cb
    custom_callback.set(None)
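
The handler prices both the legacy llm_output["token_usage"] payload and the newer usage_metadata field, and get_token_cost_for_model quietly returns 0.0 for models missing from the cost tables. A worked example against the rates shipped in model_costs.py (the import path assumes the module lands at scrapegraphai/utils/custom_callback.py, as the manager's import suggests):

from scrapegraphai.utils.custom_callback import get_token_cost_for_model

# mistral-large-2407: $0.002 per 1K input tokens, $0.006 per 1K output tokens
prompt_cost = get_token_cost_for_model("mistral-large-2407", 1500)
completion_cost = get_token_cost_for_model("mistral-large-2407", 500, is_completion=True)
print(prompt_cost, completion_cost)  # 0.003 0.003
print(get_token_cost_for_model("some-unlisted-model", 1000))  # 0.0 for unknown models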

scrapegraphai/utils/custom_openai_callback.py

Lines changed: 0 additions & 17 deletions
This file was deleted.
scrapegraphai/utils/llm_callback_manager.py

Lines changed: 38 additions & 0 deletions

"""
This module provides a custom callback manager for the LLM models.
"""
import threading
from contextlib import contextmanager
from .custom_callback import get_custom_callback

from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_aws import ChatBedrock


class CustomLLMCallbackManager:
    _lock = threading.Lock()

    @contextmanager
    def exclusive_get_callback(self, llm_model, llm_model_name):
        if CustomLLMCallbackManager._lock.acquire(blocking=False):
            if isinstance(llm_model, (ChatOpenAI, AzureChatOpenAI)):
                try:
                    with get_openai_callback() as cb:
                        yield cb
                finally:
                    CustomLLMCallbackManager._lock.release()
            elif (isinstance(llm_model, ChatBedrock)
                  and llm_model_name is not None and "claude" in llm_model_name):
                try:
                    with get_bedrock_anthropic_callback() as cb:
                        yield cb
                finally:
                    CustomLLMCallbackManager._lock.release()
            else:
                try:
                    with get_custom_callback(llm_model_name) as cb:
                        yield cb
                finally:
                    CustomLLMCallbackManager._lock.release()
        else:
            yield None
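
A sketch of how the manager is meant to be driven: OpenAI and Azure models get LangChain's OpenAI callback, Claude on Bedrock gets the Anthropic Bedrock callback, and everything else (Mistral, Titan, Llama, ...) falls through to the custom handler backed by the cost tables. Here llm_model is left as None purely to force the fallback branch; in the graph it is the node's chat model:

from scrapegraphai.utils.llm_callback_manager import CustomLLMCallbackManager

manager = CustomLLMCallbackManager()
llm_model = None  # stands in for the chat model carried by the graph nodes

with manager.exclusive_get_callback(llm_model, "mistral.mistral-large-2407-v1:0") as cb:
    # ... execute the node that calls the LLM ...
    if cb is not None:  # cb is None when another thread already holds the lock
        print(cb)  # accumulated token counts and total cost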

scrapegraphai/utils/model_costs.py

Lines changed: 105 additions & 0 deletions
"""
This file contains the cost of models per 1k tokens for input and output.
The file is on a best-effort basis and may not be up to date. Any contributions are welcome.
"""
MODEL_COST_PER_1K_TOKENS_INPUT = {
    ### MistralAI
    # General Purpose
    "open-mistral-nemo": 0.00015,
    "open-mistral-nemo-2407": 0.00015,
    "mistral-large": 0.002,
    "mistral-large-2407": 0.002,
    "mistral-small": 0.0002,
    "mistral-small-2409": 0.0002,
    # Specialist Models
    "codestral": 0.0002,
    "codestral-2405": 0.0002,
    "pixtral-12b": 0.00015,
    "pixtral-12b-2409": 0.00015,
    # Legacy Models
    "open-mistral-7b": 0.00025,
    "open-mixtral-8x7b": 0.0007,
    "open-mixtral-8x22b": 0.002,
    "mistral-small-latest": 0.001,
    "mistral-medium-latest": 0.00275,

    ### Bedrock - not Claude
    # AI21 Labs
    "ai21.j2-ultra-v1": 0.0188,
    "ai21.j2-mid-v1": 0.0125,
    "ai21.jamba-instruct-v1:0": 0.0005,
    # Meta - Llama
    "meta.llama2-13b-chat-v1": 0.00075,
    "meta.llama2-70b-chat-v1": 0.00195,
    "meta.llama3-8b-instruct-v1:0": 0.0003,
    "meta.llama3-70b-instruct-v1:0": 0.00265,
    "meta.llama3-1-8b-instruct-v1:0": 0.00022,
    "meta.llama3-1-70b-instruct-v1:0": 0.00099,
    "meta.llama3-1-405b-instruct-v1:0": 0.00532,
    # Cohere - Command
    "cohere.command-text-v14": 0.0015,
    "cohere.command-light-text-v14": 0.0003,
    "cohere.command-r-v1:0": 0.0005,
    "cohere.command-r-plus-v1:0": 0.003,
    # Mistral
    "mistral.mistral-7b-instruct-v0:2": 0.00015,
    "mistral.mistral-large-2402-v1:0": 0.004,
    "mistral.mistral-large-2407-v1:0": 0.002,
    "mistral.mistral-small-2402-v1:0": 0.001,
    "mistral.mixtral-8x7b-instruct-v0:1": 0.00045,
    # Amazon - Titan
    "amazon.titan-text-express-v1": 0.0002,
    "amazon.titan-text-lite-v1": 0.00015,
    "amazon.titan-text-premier-v1:0": 0.0005,
}

MODEL_COST_PER_1K_TOKENS_OUTPUT = {
    ### MistralAI
    # General Purpose
    "open-mistral-nemo": 0.00015,
    "open-mistral-nemo-2407": 0.00015,
    "mistral-large": 0.002,
    "mistral-large-2407": 0.006,
    "mistral-small": 0.0002,
    "mistral-small-2409": 0.0006,
    # Specialist Models
    "codestral": 0.0006,
    "codestral-2405": 0.0006,
    "pixtral-12b": 0.00015,
    "pixtral-12b-2409": 0.0006,
    # Legacy Models
    "open-mistral-7b": 0.00025,
    "open-mixtral-8x7b": 0.0007,
    "open-mixtral-8x22b": 0.006,
    "mistral-small-latest": 0.003,
    "mistral-medium-latest": 0.0081,

    ### Bedrock - not Claude
    # AI21 Labs
    "ai21.j2-ultra-v1": 0.0188,
    "ai21.j2-mid-v1": 0.0125,
    "ai21.jamba-instruct-v1:0": 0.0007,
    # Meta - Llama
    "meta.llama2-13b-chat-v1": 0.001,
    "meta.llama2-70b-chat-v1": 0.00256,
    "meta.llama3-8b-instruct-v1:0": 0.0006,
    "meta.llama3-70b-instruct-v1:0": 0.0035,
    "meta.llama3-1-8b-instruct-v1:0": 0.00022,
    "meta.llama3-1-70b-instruct-v1:0": 0.00099,
    "meta.llama3-1-405b-instruct-v1:0": 0.016,
    # Cohere - Command
    "cohere.command-text-v14": 0.002,
    "cohere.command-light-text-v14": 0.0006,
    "cohere.command-r-v1:0": 0.0015,
    "cohere.command-r-plus-v1:0": 0.015,
    # Mistral
    "mistral.mistral-7b-instruct-v0:2": 0.0002,
    "mistral.mistral-large-2402-v1:0": 0.012,
    "mistral.mistral-large-2407-v1:0": 0.006,
    "mistral.mistral-small-2402-v1:0": 0.003,
    "mistral.mixtral-8x7b-instruct-v0:1": 0.0007,
    # Amazon - Titan
    "amazon.titan-text-express-v1": 0.0006,
    "amazon.titan-text-lite-v1": 0.0002,
    "amazon.titan-text-premier-v1:0": 0.0015,
}
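
The per-model cost functions live in custom_callback.py, so pricing a whole request from these tables is a two-line lookup. A small illustrative helper (request_cost is hypothetical, not part of the commit):

from scrapegraphai.utils.model_costs import (
    MODEL_COST_PER_1K_TOKENS_INPUT,
    MODEL_COST_PER_1K_TOKENS_OUTPUT,
)

def request_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    # Best-effort USD cost of one request; unknown models price at 0.0.
    if model not in MODEL_COST_PER_1K_TOKENS_INPUT:
        return 0.0
    return (MODEL_COST_PER_1K_TOKENS_INPUT[model] * prompt_tokens
            + MODEL_COST_PER_1K_TOKENS_OUTPUT[model] * completion_tokens) / 1000

# amazon.titan-text-express-v1: $0.0002 in / $0.0006 out per 1K tokens
print(request_cost("amazon.titan-text-express-v1", 2000, 1000))  # 0.0004 + 0.0006 = 0.001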
