feat: added Bedrock and Mistral to exec info #680

Merged
17 changes: 10 additions & 7 deletions scrapegraphai/graphs/base_graph.py
@@ -5,7 +5,7 @@
import warnings
from typing import Tuple
from ..telemetry import log_graph_execution
from ..utils import CustomOpenAiCallbackManager
from ..utils import CustomLLMCallbackManager

class BaseGraph:
"""
@@ -52,7 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
self.entry_point = entry_point.node_name
self.graph_name = graph_name
self.initial_state = {}
self.callback_manager = CustomOpenAiCallbackManager()
self.callback_manager = CustomLLMCallbackManager()

if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
@@ -108,6 +108,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
error_node = None
source_type = None
llm_model = None
llm_model_name = None
embedder_model = None
source = []
prompt = None
@@ -135,9 +136,11 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
if hasattr(current_node, "llm_model") and llm_model is None:
llm_model = current_node.llm_model
if hasattr(llm_model, "model_name"):
llm_model = llm_model.model_name
llm_model_name = llm_model.model_name
elif hasattr(llm_model, "model"):
llm_model = llm_model.model
llm_model_name = llm_model.model
elif hasattr(llm_model, "model_id"):
llm_model_name = llm_model.model_id

if hasattr(current_node, "embedder_model") and embedder_model is None:
embedder_model = current_node.embedder_model
@@ -155,7 +158,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
except Exception as e:
schema = None

with self.callback_manager.exclusive_get_openai_callback() as cb:
with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
try:
result = current_node.execute(state)
except Exception as e:
@@ -166,7 +169,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
source=source,
prompt=prompt,
schema=schema,
llm_model=llm_model,
llm_model=llm_model_name,
embedder_model=embedder_model,
source_type=source_type,
execution_time=graph_execution_time,
@@ -222,7 +225,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
source=source,
prompt=prompt,
schema=schema,
llm_model=llm_model,
llm_model=llm_model_name,
embedder_model=embedder_model,
source_type=source_type,
content=content,
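
Note: the hunks above keep the original llm_model object intact and record its identifier separately, falling back from model_name to model to model_id so that OpenAI-, Mistral- and Bedrock-style wrappers are all covered. A minimal sketch of that fallback (not part of the PR), using hypothetical stand-in classes instead of real chat-model wrappers:

```python
from typing import Optional

class _OpenAIStyle:   # hypothetical stand-in: OpenAI-style wrappers expose `model_name`
    model_name = "gpt-4o-mini"

class _BedrockStyle:  # hypothetical stand-in: Bedrock-style wrappers expose `model_id`
    model_id = "anthropic.claude-3-haiku-20240307-v1:0"

def resolve_model_name(llm_model) -> Optional[str]:
    # Mirrors the attribute fallback introduced in _execute_standard
    for attr in ("model_name", "model", "model_id"):
        if hasattr(llm_model, attr):
            return getattr(llm_model, attr)
    return None

print(resolve_model_name(_OpenAIStyle()))   # -> gpt-4o-mini
print(resolve_model_name(_BedrockStyle()))  # -> anthropic.claude-3-haiku-20240307-v1:0
```
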
2 changes: 1 addition & 1 deletion scrapegraphai/utils/__init__.py
@@ -17,4 +17,4 @@
from .screenshot_scraping.text_detection import detect_text
from .tokenizer import num_tokens_calculus
from .split_text_into_chunks import split_text_into_chunks
from .custom_openai_callback import CustomOpenAiCallbackManager
from .llm_callback_manager import CustomLLMCallbackManager
157 changes: 157 additions & 0 deletions scrapegraphai/utils/custom_callback.py
@@ -0,0 +1,157 @@
"""
Custom callback for LLM token usage statistics.

This module is adapted from the OpenAI callback manager in langchain-community.
https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py
"""
from contextlib import contextmanager
import threading
from typing import Any, Dict, List, Optional
from contextvars import ContextVar

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.tracers.context import register_configure_hook

from .model_costs import MODEL_COST_PER_1K_TOKENS_INPUT, MODEL_COST_PER_1K_TOKENS_OUTPUT


def get_token_cost_for_model(
model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
"""
Get the cost in USD for a given model and number of tokens.

Args:
model_name: Name of the model.
num_tokens: Number of tokens.
is_completion: Whether the model is used for completion or not.
Defaults to False.

Returns:
Cost in USD.
"""
if model_name not in MODEL_COST_PER_1K_TOKENS_INPUT:
return 0.0
if is_completion:
return MODEL_COST_PER_1K_TOKENS_OUTPUT[model_name] * (num_tokens / 1000)

return MODEL_COST_PER_1K_TOKENS_INPUT[model_name] * (num_tokens / 1000)
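
For instance, under the pricing table added in model_costs.py below, a request to mistral-large-2407 with 2,000 prompt tokens and 500 completion tokens would be priced roughly as follows (a usage sketch, not part of the diff):

```python
from scrapegraphai.utils.custom_callback import get_token_cost_for_model

prompt_cost = get_token_cost_for_model("mistral-large-2407", 2000)                         # 2 * 0.002 = 0.004 USD
completion_cost = get_token_cost_for_model("mistral-large-2407", 500, is_completion=True)  # 0.5 * 0.006 = 0.003 USD
print(round(prompt_cost + completion_cost, 6))  # 0.007
```
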


class CustomCallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks LLMs info."""

total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0

def __init__(self, llm_model_name: str) -> None:
super().__init__()
self._lock = threading.Lock()
self.model_name = llm_model_name if llm_model_name else "unknown"

def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost}"
)

@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)
try:
generation = response.generations[0][0]
except IndexError:
generation = None
if isinstance(generation, ChatGeneration):
try:
message = generation.message
if isinstance(message, AIMessage):
usage_metadata = message.usage_metadata
else:
usage_metadata = None
except AttributeError:
usage_metadata = None
else:
usage_metadata = None
if usage_metadata:
token_usage = {"total_tokens": usage_metadata["total_tokens"]}
completion_tokens = usage_metadata["output_tokens"]
prompt_tokens = usage_metadata["input_tokens"]


else:
if response.llm_output is None:
return None

if "token_usage" not in response.llm_output:
with self._lock:
self.successful_requests += 1
return None

# compute tokens and cost for this request
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
if self.model_name in MODEL_COST_PER_1K_TOKENS_INPUT:
completion_cost = get_token_cost_for_model(
self.model_name, completion_tokens, is_completion=True
)
prompt_cost = get_token_cost_for_model(self.model_name, prompt_tokens)
else:
completion_cost = 0
prompt_cost = 0

# update shared state behind lock
with self._lock:
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
self.successful_requests += 1

def __copy__(self) -> "CustomCallbackHandler":
"""Return a copy of the callback handler."""
return self

def __deepcopy__(self, memo: Any) -> "CustomCallbackHandler":
"""Return a deep copy of the callback handler."""
return self


custom_callback: ContextVar[Optional[CustomCallbackHandler]] = ContextVar(
"custom_callback", default=None
)
register_configure_hook(custom_callback, True)

@contextmanager
def get_custom_callback(llm_model_name: str):
"""
Function to get custom callback for LLM token usage statistics.
"""
cb = CustomCallbackHandler(llm_model_name)
custom_callback.set(cb)
yield cb
custom_callback.set(None)
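
A small usage sketch (not part of the diff): feeding the handler a synthetic LLMResult whose AIMessage carries usage_metadata, which is the path taken on langchain-core >= 0.2.2.

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from scrapegraphai.utils.custom_callback import get_custom_callback

# Synthetic response, shaped like what a chat model returns.
message = AIMessage(
    content="ok",
    usage_metadata={"input_tokens": 120, "output_tokens": 30, "total_tokens": 150},
)
result = LLMResult(generations=[[ChatGeneration(message=message)]])

with get_custom_callback("open-mistral-nemo") as cb:
    cb.on_llm_end(result)  # normally invoked by LangChain after each LLM call
    print(cb)  # expect 150 total tokens (120 prompt / 30 completion) plus the Mistral cost
```
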
17 changes: 0 additions & 17 deletions scrapegraphai/utils/custom_openai_callback.py

This file was deleted.

38 changes: 38 additions & 0 deletions scrapegraphai/utils/llm_callback_manager.py
@@ -0,0 +1,38 @@
"""
This module provides a custom callback manager for the LLM models.
"""
import threading
from contextlib import contextmanager
from .custom_callback import get_custom_callback

from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_aws import ChatBedrock

class CustomLLMCallbackManager:
_lock = threading.Lock()

@contextmanager
def exclusive_get_callback(self, llm_model, llm_model_name):
if CustomLLMCallbackManager._lock.acquire(blocking=False):
if isinstance(llm_model, ChatOpenAI) or isinstance(llm_model, AzureChatOpenAI):
try:
with get_openai_callback() as cb:
yield cb
finally:
CustomLLMCallbackManager._lock.release()
elif isinstance(llm_model, ChatBedrock) and llm_model_name is not None and "claude" in llm_model_name:
try:
with get_bedrock_anthropic_callback() as cb:
yield cb
finally:
CustomLLMCallbackManager._lock.release()
else:
try:
with get_custom_callback(llm_model_name) as cb:
yield cb
finally:
CustomLLMCallbackManager._lock.release()
else:
yield None
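
Usage sketch (not part of the diff): a model that is neither ChatOpenAI/AzureChatOpenAI nor a Claude ChatBedrock falls through to the custom handler, keyed only by the model name string. Because the lock is acquired with blocking=False, only one node at a time gets a live callback; a concurrent caller receives None and should handle that case.

```python
from scrapegraphai.utils import CustomLLMCallbackManager

manager = CustomLLMCallbackManager()

# llm_model=None stands in for any wrapper that is not ChatOpenAI/AzureChatOpenAI/ChatBedrock.
with manager.exclusive_get_callback(None, "mistral-large-2407") as cb:
    # run the node here; LangChain routes token usage into cb via the registered hook
    print(cb)  # CustomCallbackHandler with zeroed counters until an LLM call completes
```
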
105 changes: 105 additions & 0 deletions scrapegraphai/utils/model_costs.py
@@ -0,0 +1,105 @@
"""
This file contains the cost of models per 1k tokens for input and output.
The figures are maintained on a best-effort basis and may not be up to date. Contributions are welcome.
"""
MODEL_COST_PER_1K_TOKENS_INPUT = {
### MistralAI
# General Purpose
"open-mistral-nemo": 0.00015,
"open-mistral-nemo-2407": 0.00015,
"mistral-large": 0.002,
"mistral-large-2407": 0.002,
"mistral-small": 0.0002,
"mistral-small-2409": 0.0002,
# Specialist Models
"codestral": 0.0002,
"codestral-2405": 0.0002,
"pixtral-12b": 0.00015,
"pixtral-12b-2409": 0.00015,
# Legacy Models
"open-mistral-7b": 0.00025,
"open-mixtral-8x7b": 0.0007,
"open-mixtral-8x22b": 0.002,
"mistral-small-latest": 0.001,
"mistral-medium-latest": 0.00275,

### Bedrock - not Claude
# AI21 Labs
"ai21.j2-ultra-v1": 0.0188,
"ai21.j2-mid-v1": 0.0125,
"ai21.jamba-instruct-v1:0": 0.0005,
# Meta - LLama
"meta.llama2-13b-chat-v1": 0.00075,
"meta.llama2-70b-chat-v1": 0.00195,
"meta.llama3-8b-instruct-v1:0": 0.0003,
"meta.llama3-70b-instruct-v1:0": 0.00265,
"meta.llama3-1-8b-instruct-v1:0": 0.00022,
"meta.llama3-1-70b-instruct-v1:0": 0.00099,
"meta.llama3-1-405b-instruct-v1:0": 0.00532,
# Cohere - Command
"cohere.command-text-v14": 0.0015,
"cohere.command-light-text-v14": 0.0003,
"cohere.command-r-v1:0": 0.0005,
"cohere.command-r-plus-v1:0": 0.003,
# Mistral
"mistral.mistral-7b-instruct-v0:2": 0.00015,
"mistral.mistral-large-2402-v1:0": 0.004,
"mistral.mistral-large-2407-v1:0": 0.002,
"mistral.mistral-small-2402-v1:0": 0.001,
"mistral.mixtral-7x8b-instruct-v0:1": 0.00045,
# Amazon - Titan
"amazon.titan-text-express-v1": 0.0002,
"amazon.titan-text-lite-v1": 0.00015,
"amazon.titan-text-premier-v1:0": 0.0005,
}

MODEL_COST_PER_1K_TOKENS_OUTPUT = {
### MistralAI
# General Purpose
"open-mistral-nemo": 0.00015,
"open-mistral-nemo-2407": 0.00015,
"mistral-large": 0.002,
"mistral-large-2407": 0.006,
"mistral-small": 0.0002,
"mistral-small-2409": 0.0006,
# Specialist Models
"codestral": 0.0006,
"codestral-2405": 0.0006,
"pixtral-12b": 0.00015,
"pixtral-12b-2409": 0.0006,
# Legacy Models
"open-mistral-7b": 0.00025,
"open-mixtral-8x7b": 0.0007,
"open-mixtral-8x22b": 0.006,
"mistral-small-latest": 0.003,
"mistral-medium-latest": 0.0081,

### Bedrock - not Claude
# AI21 Labs
"a121.ju-ultra-v1": 0.0188,
"a121.ju-mid-v1": 0.0125,
"ai21.jamba-instruct-v1:0": 0.0007,
# Meta - LLama
"meta.llama2-13b-chat-v1": 0.001,
"meta.llama2-70b-chat-v1": 0.00256,
"meta.llama3-8b-instruct-v1:0": 0.0006,
"meta.llama3-70b-instruct-v1:0": 0.0035,
"meta.llama3-1-8b-instruct-v1:0": 0.00022,
"meta.llama3-1-70b-instruct-v1:0": 0.00099,
"meta.llama3-1-405b-instruct-v1:0": 0.016,
# Cohere - Command
"cohere.command-text-v14": 0.002,
"cohere.command-light-text-v14": 0.0006,
"cohere.command-r-v1:0": 0.0015,
"cohere.command-r-plus-v1:0": 0.015,
# Mistral
"mistral.mistral-7b-instruct-v0:2": 0.0002,
"mistral.mistral-large-2402-v1:0": 0.012,
"mistral.mistral-large-2407-v1:0": 0.006,
"mistral.mistral-small-2402-v1:0": 0.003,
"mistral.mixtral-7x8b-instruct-v0:1": 0.0007,
# Amazon - Titan
"amazon.titan-text-express-v1": 0.0006,
"amazon.titan-text-lite-v1": 0.0002,
"amazon.titan-text-premier-v1:0": 0.0015,
}
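
As a quick sanity check of the table (a sketch, not part of the PR), estimating the cost of one Bedrock Llama 3 70B request from the two dictionaries:

```python
from scrapegraphai.utils.model_costs import (
    MODEL_COST_PER_1K_TOKENS_INPUT,
    MODEL_COST_PER_1K_TOKENS_OUTPUT,
)

model = "meta.llama3-70b-instruct-v1:0"
prompt_tokens, completion_tokens = 4000, 800

estimate = (
    MODEL_COST_PER_1K_TOKENS_INPUT[model] * prompt_tokens / 1000
    + MODEL_COST_PER_1K_TOKENS_OUTPUT[model] * completion_tokens / 1000
)
print(f"${estimate:.5f}")  # 4 * 0.00265 + 0.8 * 0.0035 = $0.01340
```
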