Skip to content

Commit 39f64e5

Browse files
committed
add claude model
1 parent 8b94fe8 commit 39f64e5

File tree

6 files changed

+65
-40
lines changed

6 files changed

+65
-40
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ free-proxy = "1.1.1"
4141
langchain-groq = "0.1.3"
4242
playwright = "^1.43.0"
4343
langchain-aws = "^0.1.2"
44-
44+
langchain-anthropic = "^0.1.11"
4545

4646
[tool.poetry.dev-dependencies]
4747
pytest = "8.0.0"

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ free-proxy==1.1.1
1515
langchain-groq==0.1.3
1616
playwright==1.43.0
1717
langchain-aws==0.1.2
18+
langchain-anthropic==0.1.11

scrapegraphai/graphs/abstract_graph.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1111

1212
from ..helpers import models_tokens
13-
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
13+
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
1414

1515

1616
class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
2222
source (str): The source of the graph.
2323
config (dict): Configuration parameters for the graph.
2424
llm_model: An instance of a language model client, configured for generating answers.
25-
embedder_model: An instance of an embedding model client, configured for generating embeddings.
25+
embedder_model: An instance of an embedding model client,
26+
configured for generating embeddings.
2627
verbose (bool): A flag indicating whether to show print statements during execution.
2728
headless (bool): A flag indicating whether to run the graph in headless mode.
2829
@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4748
self.source = source
4849
self.config = config
4950
self.llm_model = self._create_llm(config["llm"], chat=True)
50-
self.embedder_model = self._create_default_embedder(
51-
) if "embeddings" not in config else self._create_embedder(
51+
self.embedder_model = self._create_default_embedder(
52+
) if "embeddings" not in config else self._create_embedder(
5253
config["embeddings"])
5354

5455
# Set common configuration parameters
@@ -61,15 +62,13 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
6162
self.final_state = None
6263
self.execution_info = None
6364

64-
6565
def _set_model_token(self, llm):
6666

6767
if 'Azure' in str(type(llm)):
6868
try:
6969
self.model_token = models_tokens["azure"][llm.model_name]
70-
except KeyError:
71-
raise KeyError("Model not supported")
72-
70+
except KeyError as exc:
71+
raise KeyError("Model not supported") from exc
7372

7473
def _create_llm(self, llm_config: dict, chat=False) -> object:
7574
"""
@@ -96,31 +95,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
9695
if chat:
9796
self._set_model_token(llm_params['model_instance'])
9897
return llm_params['model_instance']
99-
98+
10099
# Instantiate the language model based on the model name
101100
if "gpt-" in llm_params["model"]:
102101
try:
103102
self.model_token = models_tokens["openai"][llm_params["model"]]
104-
except KeyError:
105-
raise KeyError("Model not supported")
103+
except KeyError as exc:
104+
raise KeyError("Model not supported") from exc
106105
return OpenAI(llm_params)
107106

108107
elif "azure" in llm_params["model"]:
109108
# take the model after the last dash
110109
llm_params["model"] = llm_params["model"].split("/")[-1]
111110
try:
112111
self.model_token = models_tokens["azure"][llm_params["model"]]
113-
except KeyError:
114-
raise KeyError("Model not supported")
112+
except KeyError as exc:
113+
raise KeyError("Model not supported") from exc
115114
return AzureOpenAI(llm_params)
116115

117116
elif "gemini" in llm_params["model"]:
118117
try:
119118
self.model_token = models_tokens["gemini"][llm_params["model"]]
120-
except KeyError:
121-
raise KeyError("Model not supported")
119+
except KeyError as exc:
120+
raise KeyError("Model not supported") from exc
122121
return Gemini(llm_params)
123-
122+
elif "claude" in llm_params["model"]:
123+
try:
124+
self.model_token = models_tokens["claude"][llm_params["model"]]
125+
except KeyError as exc:
126+
raise KeyError("Model not supported") from exc
127+
return Claude(llm_params)
124128
elif "ollama" in llm_params["model"]:
125129
llm_params["model"] = llm_params["model"].split("/")[-1]
126130

@@ -131,8 +135,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
131135
elif llm_params["model"] in models_tokens["ollama"]:
132136
try:
133137
self.model_token = models_tokens["ollama"][llm_params["model"]]
134-
except KeyError:
135-
raise KeyError("Model not supported")
138+
except KeyError as exc:
139+
raise KeyError("Model not supported") from exc
136140
else:
137141
self.model_token = 8192
138142
except AttributeError:
@@ -142,25 +146,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
142146
elif "hugging_face" in llm_params["model"]:
143147
try:
144148
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
145-
except KeyError:
146-
raise KeyError("Model not supported")
149+
except KeyError as exc:
150+
raise KeyError("Model not supported") from exc
147151
return HuggingFace(llm_params)
148152
elif "groq" in llm_params["model"]:
149153
llm_params["model"] = llm_params["model"].split("/")[-1]
150154

151155
try:
152156
self.model_token = models_tokens["groq"][llm_params["model"]]
153-
except KeyError:
154-
raise KeyError("Model not supported")
157+
except KeyError as exc:
158+
raise KeyError("Model not supported") from exc
155159
return Groq(llm_params)
156160
elif "bedrock" in llm_params["model"]:
157161
llm_params["model"] = llm_params["model"].split("/")[-1]
158162
model_id = llm_params["model"]
159163

160164
try:
161165
self.model_token = models_tokens["bedrock"][llm_params["model"]]
162-
except KeyError:
163-
raise KeyError("Model not supported")
166+
except KeyError as exc:
167+
raise KeyError("Model not supported") from exc
164168
return Bedrock({
165169
"model_id": model_id,
166170
"model_kwargs": {
@@ -170,7 +174,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
170174
else:
171175
raise ValueError(
172176
"Model provided by the configuration not supported")
173-
177+
174178
def _create_default_embedder(self) -> object:
175179
"""
176180
Create an embedding model instance based on the chosen llm model.
@@ -202,7 +206,7 @@ def _create_default_embedder(self) -> object:
202206
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
203207
else:
204208
raise ValueError("Embedding Model missing or not supported")
205-
209+
206210
def _create_embedder(self, embedder_config: dict) -> object:
207211
"""
208212
Create an embedding model instance based on the configuration provided.
@@ -216,7 +220,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
216220
Raises:
217221
KeyError: If the model is not supported.
218222
"""
219-
223+
220224
# Instantiate the embedding model based on the model name
221225
if "openai" in embedder_config["model"]:
222226
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@@ -228,27 +232,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
228232
embedder_config["model"] = embedder_config["model"].split("/")[-1]
229233
try:
230234
models_tokens["ollama"][embedder_config["model"]]
231-
except KeyError:
232-
raise KeyError("Model not supported")
235+
except KeyError as exc:
236+
raise KeyError("Model not supported") from exc
233237
return OllamaEmbeddings(**embedder_config)
234-
238+
235239
elif "hugging_face" in embedder_config["model"]:
236240
try:
237241
models_tokens["hugging_face"][embedder_config["model"]]
238-
except KeyError:
239-
raise KeyError("Model not supported")
242+
except KeyError as exc:
243+
raise KeyError("Model not supported")from exc
240244
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
241-
245+
242246
elif "bedrock" in embedder_config["model"]:
243247
embedder_config["model"] = embedder_config["model"].split("/")[-1]
244248
try:
245249
models_tokens["bedrock"][embedder_config["model"]]
246-
except KeyError:
247-
raise KeyError("Model not supported")
250+
except KeyError as exc:
251+
raise KeyError("Model not supported") from exc
248252
return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
249253
else:
250254
raise ValueError(
251-
"Model provided by the configuration not supported")
255+
"Model provided by the configuration not supported")
252256

253257
def get_state(self, key=None) -> dict:
254258
"""""
@@ -272,7 +276,7 @@ def get_execution_info(self):
272276
Returns:
273277
dict: The execution information of the graph.
274278
"""
275-
279+
276280
return self.execution_info
277281

278282
@abstractmethod
@@ -288,4 +292,3 @@ def run(self) -> str:
288292
Abstract method to execute the graph and return the result.
289293
"""
290294
pass
291-

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
from .hugging_face import HuggingFace
1212
from .groq import Groq
1313
from .bedrock import Bedrock
14+
from .claude import Claude

scrapegraphai/models/claude.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
Claude Module
3+
"""
4+
5+
from langchain_anthropic import ChatAnthropic
6+
7+
8+
class Claude(ChatAnthropic):
9+
"""
10+
A wrapper for the ChatAnthropic class that provides default configuration
11+
and could be extended with additional methods if needed.
12+
13+
Args:
14+
llm_config (dict): Configuration parameters for the language model
15+
(e.g., model="claude_instant")
16+
"""
17+
18+
def __init__(self, llm_config: dict):
19+
super().__init__(**llm_config)

scrapegraphai/models/gemini.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ class Gemini(ChatGoogleGenerativeAI):
1010
and could be extended with additional methods if needed.
1111
1212
Args:
13-
llm_config (dict): Configuration parameters for the language model (e.g., model="gemini-pro")
13+
llm_config (dict): Configuration parameters for the language model
14+
(e.g., model="gemini-pro")
1415
"""
1516

1617
def __init__(self, llm_config: dict):

0 commit comments

Comments
 (0)