Commit a94ebcd

refactor: move embeddings code from AbstractGraph to RAGNode
1 parent bb73d91 commit a94ebcd

2 files changed: +142 −125 lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 2 additions & 121 deletions
@@ -7,15 +7,8 @@
 import uuid
 from pydantic import BaseModel

-from langchain_community.chat_models import ChatOllama, ErnieBotChat
-from langchain_aws import BedrockEmbeddings, ChatBedrock
-from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
-from langchain_community.embeddings import OllamaEmbeddings
-from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
-from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
-from langchain_fireworks import FireworksEmbeddings, ChatFireworks
-from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
-from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
+from langchain_community.chat_models import ErnieBotChat
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from langchain.chat_models import init_chat_model

 from ..helpers import models_tokens
@@ -66,8 +59,6 @@ def __init__(self, prompt: str, config: dict,
         self.config = config
         self.schema = schema
         self.llm_model = self._create_llm(config["llm"])
-        self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
-            config["embeddings"])
         self.verbose = False if config is None else config.get(
             "verbose", False)
         self.headless = True if config is None else config.get(
@@ -237,116 +228,6 @@ def handle_model(model_name, provider, token_key, default_token=8192):
         # Raise an error if the model did not match any of the previous cases
         raise ValueError("Model provided by the configuration not supported")

-    def _create_default_embedder(self, llm_config=None) -> object:
-        """
-        Create an embedding model instance based on the chosen llm model.
-
-        Returns:
-            object: An instance of the embedding model client.
-
-        Raises:
-            ValueError: If the model is not supported.
-        """
-        if isinstance(self.llm_model, ChatGoogleGenerativeAI):
-            return GoogleGenerativeAIEmbeddings(
-                google_api_key=llm_config["api_key"], model="models/embedding-001"
-            )
-        if isinstance(self.llm_model, ChatOpenAI):
-            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
-                                    base_url=self.llm_model.openai_api_base)
-        elif isinstance(self.llm_model, DeepSeek):
-            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
-        elif isinstance(self.llm_model, ChatVertexAI):
-            return VertexAIEmbeddings()
-        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
-            return self.llm_model
-        elif isinstance(self.llm_model, AzureChatOpenAI):
-            return AzureOpenAIEmbeddings()
-        elif isinstance(self.llm_model, ChatFireworks):
-            return FireworksEmbeddings(model=self.llm_model.model_name)
-        elif isinstance(self.llm_model, ChatNVIDIA):
-            return NVIDIAEmbeddings(model=self.llm_model.model_name)
-        elif isinstance(self.llm_model, ChatOllama):
-            # unwrap the kwargs from the model, which is a dict
-            params = self.llm_model._lc_kwargs
-            # remove streaming and temperature
-            params.pop("streaming", None)
-            params.pop("temperature", None)
-
-            return OllamaEmbeddings(**params)
-        elif isinstance(self.llm_model, ChatHuggingFace):
-            return HuggingFaceEmbeddings(model=self.llm_model.model)
-        elif isinstance(self.llm_model, ChatBedrock):
-            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
-        else:
-            raise ValueError("Embedding Model missing or not supported")
-
-    def _create_embedder(self, embedder_config: dict) -> object:
-        """
-        Create an embedding model instance based on the configuration provided.
-
-        Args:
-            embedder_config (dict): Configuration parameters for the embedding model.
-
-        Returns:
-            object: An instance of the embedding model client.
-
-        Raises:
-            KeyError: If the model is not supported.
-        """
-        embedder_params = {**embedder_config}
-        if "model_instance" in embedder_config:
-            return embedder_params["model_instance"]
-        # Instantiate the embedding model based on the model name
-        if "openai" in embedder_params["model"]:
-            return OpenAIEmbeddings(api_key=embedder_params["api_key"])
-        if "azure" in embedder_params["model"]:
-            return AzureOpenAIEmbeddings()
-        if "nvidia" in embedder_params["model"]:
-            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
-            try:
-                models_tokens["nvidia"][embedder_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return NVIDIAEmbeddings(model=embedder_params["model"],
-                                    nvidia_api_key=embedder_params["api_key"])
-        if "ollama" in embedder_params["model"]:
-            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
-            try:
-                models_tokens["ollama"][embedder_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return OllamaEmbeddings(**embedder_params)
-        if "hugging_face" in embedder_params["model"]:
-            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
-            try:
-                models_tokens["hugging_face"][embedder_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return HuggingFaceEmbeddings(model=embedder_params["model"])
-        if "fireworks" in embedder_params["model"]:
-            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
-            try:
-                models_tokens["fireworks"][embedder_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return FireworksEmbeddings(model=embedder_params["model"])
-        if "gemini" in embedder_params["model"]:
-            try:
-                models_tokens["gemini"][embedder_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
-        if "bedrock" in embedder_params["model"]:
-            embedder_params["model"] = embedder_params["model"].split("/")[-1]
-            client = embedder_params.get("client", None)
-            try:
-                models_tokens["bedrock"][embedder_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
-
-        raise ValueError("Model provided by the configuration not supported")

     def get_state(self, key=None) -> dict:
         """

scrapegraphai/nodes/rag_node.py

Lines changed: 140 additions & 4 deletions
@@ -14,8 +14,20 @@
 from langchain_community.document_transformers import EmbeddingsRedundantFilter
 from langchain_community.vectorstores import FAISS

+from langchain_community.chat_models import ChatOllama
+from langchain_aws import BedrockEmbeddings, ChatBedrock
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
+from langchain_fireworks import FireworksEmbeddings, ChatFireworks
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
+from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
+
 from ..utils.logging import get_logger
 from .base_node import BaseNode
+from ..helpers import models_tokens
+from ..models import DeepSeek


 class RAGNode(BaseNode):
@@ -95,10 +107,21 @@ def execute(self, state: dict) -> dict:
         self.logger.info("--- (updated chunks metadata) ---")

         # check if embedder_model is provided, if not use llm_model
-        self.embedder_model = (
-            self.embedder_model if self.embedder_model else self.llm_model
-        )
-        embeddings = self.embedder_model
+        if self.embedder_model is not None:
+            embeddings = self.embedder_model
+        elif 'embeddings' in self.node_config:
+            try:
+                embeddings = self._create_embedder(self.node_config['embeddings'])
+            except Exception:
+                try:
+                    embeddings = self._create_default_embedder()
+                    self.embedder_model = embeddings
+                except ValueError:
+                    embeddings = self.llm_model
+                    self.embedder_model = self.llm_model
+        else:
+            embeddings = self.llm_model
+            self.embedder_model = self.llm_model

         folder_name = self.node_config.get("cache_path", "cache")

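The replacement logic resolves the embedder in a fixed order: an embedder_model already set on the node wins; otherwise an "embeddings" entry in node_config is tried; if that fails, a default embedder is derived from the LLM; and as a last resort the LLM itself is reused. A standalone sketch of that resolution order, assuming the two helper methods added below (the caching of the result onto self.embedder_model is omitted for brevity):

    # Sketch of the fallback chain above; "node" stands in for the RAGNode instance.
    def resolve_embedder(node) -> object:
        # 1. an embedder explicitly set on the node wins
        if node.embedder_model is not None:
            return node.embedder_model
        # 2. a config-driven embedder, falling back to an LLM-derived default,
        #    then to the LLM itself if no default is supported
        if "embeddings" in node.node_config:
            try:
                return node._create_embedder(node.node_config["embeddings"])
            except Exception:
                try:
                    return node._create_default_embedder()
                except ValueError:
                    return node.llm_model
        # 3. with no "embeddings" entry, the LLM doubles as the embedder
        return node.llm_model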
@@ -141,3 +164,116 @@ def execute(self, state: dict) -> dict:

         state.update({self.output[0]: compressed_docs})
         return state
+
+
+    def _create_default_embedder(self, llm_config=None) -> object:
+        """
+        Create an embedding model instance based on the chosen llm model.
+
+        Returns:
+            object: An instance of the embedding model client.
+
+        Raises:
+            ValueError: If the model is not supported.
+        """
+        if isinstance(self.llm_model, ChatGoogleGenerativeAI):
+            return GoogleGenerativeAIEmbeddings(
+                google_api_key=llm_config["api_key"], model="models/embedding-001"
+            )
+        if isinstance(self.llm_model, ChatOpenAI):
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
+                                    base_url=self.llm_model.openai_api_base)
+        elif isinstance(self.llm_model, DeepSeek):
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
+        elif isinstance(self.llm_model, ChatVertexAI):
+            return VertexAIEmbeddings()
+        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
+            return self.llm_model
+        elif isinstance(self.llm_model, AzureChatOpenAI):
+            return AzureOpenAIEmbeddings()
+        elif isinstance(self.llm_model, ChatFireworks):
+            return FireworksEmbeddings(model=self.llm_model.model_name)
+        elif isinstance(self.llm_model, ChatNVIDIA):
+            return NVIDIAEmbeddings(model=self.llm_model.model_name)
+        elif isinstance(self.llm_model, ChatOllama):
+            # unwrap the kwargs from the model, which is a dict
+            params = self.llm_model._lc_kwargs
+            # remove streaming and temperature
+            params.pop("streaming", None)
+            params.pop("temperature", None)
+
+            return OllamaEmbeddings(**params)
+        elif isinstance(self.llm_model, ChatHuggingFace):
+            return HuggingFaceEmbeddings(model=self.llm_model.model)
+        elif isinstance(self.llm_model, ChatBedrock):
+            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
+        else:
+            raise ValueError("Embedding Model missing or not supported")
+
+
+    def _create_embedder(self, embedder_config: dict) -> object:
+        """
+        Create an embedding model instance based on the configuration provided.
+
+        Args:
+            embedder_config (dict): Configuration parameters for the embedding model.
+
+        Returns:
+            object: An instance of the embedding model client.
+
+        Raises:
+            KeyError: If the model is not supported.
+        """
+        embedder_params = {**embedder_config}
+        if "model_instance" in embedder_config:
+            return embedder_params["model_instance"]
+        # Instantiate the embedding model based on the model name
+        if "openai" in embedder_params["model"]:
+            return OpenAIEmbeddings(api_key=embedder_params["api_key"])
+        if "azure" in embedder_params["model"]:
+            return AzureOpenAIEmbeddings()
+        if "nvidia" in embedder_params["model"]:
+            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
+            try:
+                models_tokens["nvidia"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return NVIDIAEmbeddings(model=embedder_params["model"],
+                                    nvidia_api_key=embedder_params["api_key"])
+        if "ollama" in embedder_params["model"]:
+            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
+            try:
+                models_tokens["ollama"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return OllamaEmbeddings(**embedder_params)
+        if "hugging_face" in embedder_params["model"]:
+            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
+            try:
+                models_tokens["hugging_face"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return HuggingFaceEmbeddings(model=embedder_params["model"])
+        if "fireworks" in embedder_params["model"]:
+            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
+            try:
+                models_tokens["fireworks"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return FireworksEmbeddings(model=embedder_params["model"])
+        if "gemini" in embedder_params["model"]:
+            try:
+                models_tokens["gemini"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
+        if "bedrock" in embedder_params["model"]:
+            embedder_params["model"] = embedder_params["model"].split("/")[-1]
+            client = embedder_params.get("client", None)
+            try:
+                models_tokens["bedrock"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
+
+        raise ValueError("Model provided by the configuration not supported")
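For most hosted providers, _create_embedder keys off a provider prefix in the model string: it strips the prefix with "/".join(model.split("/")[1:]), checks the remainder against models_tokens, and instantiates the matching client. An isolated sketch of that prefix convention (model name illustrative; the models_tokens stub stands in for the real table imported from ..helpers):

    # Isolated sketch of the provider-prefix handling used by _create_embedder.
    models_tokens = {"ollama": {"nomic-embed-text": 8192}}  # stand-in for ..helpers

    def split_provider(model: str) -> tuple[str, str]:
        provider, *rest = model.split("/")
        # rejoining preserves any slashes inside the model name itself
        return provider, "/".join(rest)

    provider, model = split_provider("ollama/nomic-embed-text")
    assert (provider, model) == ("ollama", "nomic-embed-text")
    if model not in models_tokens.get(provider, {}):
        raise KeyError("Model not supported")

Two providers deviate from this pattern in the code above: "gemini" models are looked up without stripping a prefix, and "bedrock" keeps only the last path segment of the model string.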
