Skip to content

Commit c06e1e9

Browse files
authored
Merge pull request #137 from VinciGit00/133-support-claude3
feat: 133 support claude3
2 parents 5aa600c + 0ab7272 commit c06e1e9

File tree

5 files changed

+62
-37
lines changed

5 files changed

+62
-37
lines changed

SECURITY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
## Reporting a Vulnerability
44

55
For reporting a vulnerability contact directly [email protected]
6+

scrapegraphai/graphs/abstract_graph.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1111

1212
from ..helpers import models_tokens
13-
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
13+
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
1414

1515

1616
class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
2222
source (str): The source of the graph.
2323
config (dict): Configuration parameters for the graph.
2424
llm_model: An instance of a language model client, configured for generating answers.
25-
embedder_model: An instance of an embedding model client, configured for generating embeddings.
25+
embedder_model: An instance of an embedding model client,
26+
configured for generating embeddings.
2627
verbose (bool): A flag indicating whether to show print statements during execution.
2728
headless (bool): A flag indicating whether to run the graph in headless mode.
2829
@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4748
self.source = source
4849
self.config = config
4950
self.llm_model = self._create_llm(config["llm"], chat=True)
50-
self.embedder_model = self._create_default_embedder(
51-
) if "embeddings" not in config else self._create_embedder(
51+
self.embedder_model = self._create_default_embedder(
52+
) if "embeddings" not in config else self._create_embedder(
5253
config["embeddings"])
5354

5455
# Set common configuration parameters
@@ -61,23 +62,21 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
6162
self.final_state = None
6263
self.execution_info = None
6364

64-
6565
def _set_model_token(self, llm):
6666

6767
if 'Azure' in str(type(llm)):
6868
try:
6969
self.model_token = models_tokens["azure"][llm.model_name]
7070
except KeyError:
7171
raise KeyError("Model not supported")
72-
72+
7373
elif 'HuggingFaceEndpoint' in str(type(llm)):
7474
if 'mistral' in llm.repo_id:
7575
try:
7676
self.model_token = models_tokens['mistral'][llm.repo_id]
7777
except KeyError:
7878
raise KeyError("Model not supported")
7979

80-
8180
def _create_llm(self, llm_config: dict, chat=False) -> object:
8281
"""
8382
Create a large language model instance based on the configuration provided.
@@ -103,31 +102,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
103102
if chat:
104103
self._set_model_token(llm_params['model_instance'])
105104
return llm_params['model_instance']
106-
105+
107106
# Instantiate the language model based on the model name
108107
if "gpt-" in llm_params["model"]:
109108
try:
110109
self.model_token = models_tokens["openai"][llm_params["model"]]
111-
except KeyError:
112-
raise KeyError("Model not supported")
110+
except KeyError as exc:
111+
raise KeyError("Model not supported") from exc
113112
return OpenAI(llm_params)
114113

115114
elif "azure" in llm_params["model"]:
116115
# take the model after the last dash
117116
llm_params["model"] = llm_params["model"].split("/")[-1]
118117
try:
119118
self.model_token = models_tokens["azure"][llm_params["model"]]
120-
except KeyError:
121-
raise KeyError("Model not supported")
119+
except KeyError as exc:
120+
raise KeyError("Model not supported") from exc
122121
return AzureOpenAI(llm_params)
123122

124123
elif "gemini" in llm_params["model"]:
125124
try:
126125
self.model_token = models_tokens["gemini"][llm_params["model"]]
127-
except KeyError:
128-
raise KeyError("Model not supported")
126+
except KeyError as exc:
127+
raise KeyError("Model not supported") from exc
129128
return Gemini(llm_params)
130-
129+
elif "claude" in llm_params["model"]:
130+
try:
131+
self.model_token = models_tokens["claude"][llm_params["model"]]
132+
except KeyError as exc:
133+
raise KeyError("Model not supported") from exc
134+
return Claude(llm_params)
131135
elif "ollama" in llm_params["model"]:
132136
llm_params["model"] = llm_params["model"].split("/")[-1]
133137

@@ -138,8 +142,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
138142
elif llm_params["model"] in models_tokens["ollama"]:
139143
try:
140144
self.model_token = models_tokens["ollama"][llm_params["model"]]
141-
except KeyError:
142-
raise KeyError("Model not supported")
145+
except KeyError as exc:
146+
raise KeyError("Model not supported") from exc
143147
else:
144148
self.model_token = 8192
145149
except AttributeError:
@@ -149,25 +153,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
149153
elif "hugging_face" in llm_params["model"]:
150154
try:
151155
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
152-
except KeyError:
153-
raise KeyError("Model not supported")
156+
except KeyError as exc:
157+
raise KeyError("Model not supported") from exc
154158
return HuggingFace(llm_params)
155159
elif "groq" in llm_params["model"]:
156160
llm_params["model"] = llm_params["model"].split("/")[-1]
157161

158162
try:
159163
self.model_token = models_tokens["groq"][llm_params["model"]]
160-
except KeyError:
161-
raise KeyError("Model not supported")
164+
except KeyError as exc:
165+
raise KeyError("Model not supported") from exc
162166
return Groq(llm_params)
163167
elif "bedrock" in llm_params["model"]:
164168
llm_params["model"] = llm_params["model"].split("/")[-1]
165169
model_id = llm_params["model"]
166170

167171
try:
168172
self.model_token = models_tokens["bedrock"][llm_params["model"]]
169-
except KeyError:
170-
raise KeyError("Model not supported")
173+
except KeyError as exc:
174+
raise KeyError("Model not supported") from exc
171175
return Bedrock({
172176
"model_id": model_id,
173177
"model_kwargs": {
@@ -177,7 +181,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
177181
else:
178182
raise ValueError(
179183
"Model provided by the configuration not supported")
180-
184+
181185
def _create_default_embedder(self) -> object:
182186
"""
183187
Create an embedding model instance based on the chosen llm model.
@@ -208,7 +212,7 @@ def _create_default_embedder(self) -> object:
208212
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
209213
else:
210214
raise ValueError("Embedding Model missing or not supported")
211-
215+
212216
def _create_embedder(self, embedder_config: dict) -> object:
213217
"""
214218
Create an embedding model instance based on the configuration provided.
@@ -237,27 +241,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
237241
embedder_config["model"] = embedder_config["model"].split("/")[-1]
238242
try:
239243
models_tokens["ollama"][embedder_config["model"]]
240-
except KeyError:
241-
raise KeyError("Model not supported")
244+
except KeyError as exc:
245+
raise KeyError("Model not supported") from exc
242246
return OllamaEmbeddings(**embedder_config)
243-
247+
244248
elif "hugging_face" in embedder_config["model"]:
245249
try:
246250
models_tokens["hugging_face"][embedder_config["model"]]
247-
except KeyError:
248-
raise KeyError("Model not supported")
251+
except KeyError as exc:
252+
raise KeyError("Model not supported")from exc
249253
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
250-
254+
251255
elif "bedrock" in embedder_config["model"]:
252256
embedder_config["model"] = embedder_config["model"].split("/")[-1]
253257
try:
254258
models_tokens["bedrock"][embedder_config["model"]]
255-
except KeyError:
256-
raise KeyError("Model not supported")
259+
except KeyError as exc:
260+
raise KeyError("Model not supported") from exc
257261
return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
258262
else:
259263
raise ValueError(
260-
"Model provided by the configuration not supported")
264+
"Model provided by the configuration not supported")
261265

262266
def get_state(self, key=None) -> dict:
263267
"""""
@@ -281,7 +285,7 @@ def get_execution_info(self):
281285
Returns:
282286
dict: The execution information of the graph.
283287
"""
284-
288+
285289
return self.execution_info
286290

287291
@abstractmethod
@@ -297,4 +301,3 @@ def run(self) -> str:
297301
Abstract method to execute the graph and return the result.
298302
"""
299303
pass
300-

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
from .hugging_face import HuggingFace
1212
from .groq import Groq
1313
from .bedrock import Bedrock
14+
from .claude import Claude

scrapegraphai/models/claude.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
Claude Module
3+
"""
4+
5+
from langchain_anthropic import ChatAnthropic
6+
7+
8+
class Claude(ChatAnthropic):
9+
"""
10+
A wrapper for the ChatAnthropic class that provides default configuration
11+
and could be extended with additional methods if needed.
12+
13+
Args:
14+
llm_config (dict): Configuration parameters for the language model
15+
(e.g., model="claude_instant")
16+
"""
17+
18+
def __init__(self, llm_config: dict):
19+
super().__init__(**llm_config)

scrapegraphai/models/gemini.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ class Gemini(ChatGoogleGenerativeAI):
1010
and could be extended with additional methods if needed.
1111
1212
Args:
13-
llm_config (dict): Configuration parameters for the language model (e.g., model="gemini-pro")
13+
llm_config (dict): Configuration parameters for the language model
14+
(e.g., model="gemini-pro")
1415
"""
1516

1617
def __init__(self, llm_config: dict):

0 commit comments

Comments
 (0)