feat: 133 support claude3 haiku and others using litellm #137

Merged
1 change: 1 addition & 0 deletions SECURITY.md
@@ -3,3 +3,4 @@
## Reporting a Vulnerability

For reporting a vulnerability contact directly [email protected]

75 changes: 39 additions & 36 deletions scrapegraphai/graphs/abstract_graph.py
@@ -10,7 +10,7 @@
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude


class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
source (str): The source of the graph.
config (dict): Configuration parameters for the graph.
llm_model: An instance of a language model client, configured for generating answers.
embedder_model: An instance of an embedding model client, configured for generating embeddings.
embedder_model: An instance of an embedding model client,
configured for generating embeddings.
verbose (bool): A flag indicating whether to show print statements during execution.
headless (bool): A flag indicating whether to run the graph in headless mode.

@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(
) if "embeddings" not in config else self._create_embedder(
self.embedder_model = self._create_default_embedder(
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])

# Set common configuration parameters
@@ -61,23 +62,21 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.final_state = None
self.execution_info = None


def _set_model_token(self, llm):

if 'Azure' in str(type(llm)):
try:
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")

elif 'HuggingFaceEndpoint' in str(type(llm)):
if 'mistral' in llm.repo_id:
try:
self.model_token = models_tokens['mistral'][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")


def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Create a large language model instance based on the configuration provided.
@@ -103,31 +102,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
if chat:
self._set_model_token(llm_params['model_instance'])
return llm_params['model_instance']

# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
try:
self.model_token = models_tokens["openai"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OpenAI(llm_params)

elif "azure" in llm_params["model"]:
            # take the model name after the last slash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["azure"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return AzureOpenAI(llm_params)

elif "gemini" in llm_params["model"]:
try:
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Gemini(llm_params)

elif "claude" in llm_params["model"]:
try:
self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Claude(llm_params)
elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]

@@ -138,8 +142,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
elif llm_params["model"] in models_tokens["ollama"]:
try:
self.model_token = models_tokens["ollama"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
else:
self.model_token = 8192
except AttributeError:
@@ -149,25 +153,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
elif "hugging_face" in llm_params["model"]:
try:
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFace(llm_params)
elif "groq" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]

try:
self.model_token = models_tokens["groq"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Groq(llm_params)
elif "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
model_id = llm_params["model"]

try:
self.model_token = models_tokens["bedrock"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Bedrock({
"model_id": model_id,
"model_kwargs": {
Expand All @@ -177,7 +181,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
else:
raise ValueError(
"Model provided by the configuration not supported")

def _create_default_embedder(self) -> object:
"""
Create an embedding model instance based on the chosen llm model.
@@ -208,7 +212,7 @@ def _create_default_embedder(self) -> object:
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")

def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
@@ -237,27 +241,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
try:
models_tokens["ollama"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_config)

elif "hugging_face" in embedder_config["model"]:
try:
models_tokens["hugging_face"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported")from exc
return HuggingFaceHubEmbeddings(model=embedder_config["model"])

elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
try:
models_tokens["bedrock"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
else:
raise ValueError(
"Model provided by the configuration not supported")
"Model provided by the configuration not supported")

def get_state(self, key=None) -> dict:
"""""
@@ -281,7 +285,7 @@ def get_execution_info(self):
Returns:
dict: The execution information of the graph.
"""

return self.execution_info

@abstractmethod
@@ -297,4 +301,3 @@ def run(self) -> str:
Abstract method to execute the graph and return the result.
"""
pass
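
To illustrate the new routing end to end, here is a minimal configuration sketch (not part of this diff). It assumes SmartScraperGraph as the entry point; the model id and the API-key field are illustrative and the id must match an entry in models_tokens["claude"]:

from scrapegraphai.graphs import SmartScraperGraph

# Any "model" containing "claude" is now routed by _create_llm to the new
# Claude wrapper after the token lookup in models_tokens["claude"].
graph_config = {
    "llm": {
        "model": "claude-3-haiku-20240307",          # assumed models_tokens key
        "anthropic_api_key": "<ANTHROPIC_API_KEY>",  # placeholder credential
    },
}

graph = SmartScraperGraph(
    prompt="List all article titles on the page",
    source="https://example.com",
    config=graph_config,
)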

1 change: 1 addition & 0 deletions scrapegraphai/models/__init__.py
@@ -11,3 +11,4 @@
from .hugging_face import HuggingFace
from .groq import Groq
from .bedrock import Bedrock
from .claude import Claude
19 changes: 19 additions & 0 deletions scrapegraphai/models/claude.py
@@ -0,0 +1,19 @@
"""
Claude Module
"""

from langchain_anthropic import ChatAnthropic


class Claude(ChatAnthropic):
"""
A wrapper for the ChatAnthropic class that provides default configuration
and could be extended with additional methods if needed.

Args:
llm_config (dict): Configuration parameters for the language model
(e.g., model="claude_instant")
"""

def __init__(self, llm_config: dict):
super().__init__(**llm_config)
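
Because the wrapper forwards llm_config straight to ChatAnthropic, it can also be instantiated directly. A minimal usage sketch, assuming langchain_anthropic is installed and an Anthropic key is available; the keyword names below come from ChatAnthropic's interface, not from this PR:

from scrapegraphai.models import Claude

# Every key in the dict is passed through to ChatAnthropic unchanged.
llm = Claude({
    "model": "claude-3-haiku-20240307",  # assumed Anthropic model id
    "anthropic_api_key": "<your-key>",   # placeholder
    "temperature": 0,
})
print(llm.invoke("Say hello in one word.").content)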
3 changes: 2 additions & 1 deletion scrapegraphai/models/gemini.py
@@ -10,7 +10,8 @@ class Gemini(ChatGoogleGenerativeAI):
and could be extended with additional methods if needed.

Args:
llm_config (dict): Configuration parameters for the language model (e.g., model="gemini-pro")
llm_config (dict): Configuration parameters for the language model
(e.g., model="gemini-pro")
"""

def __init__(self, llm_config: dict):