Skip to content

Commit 2abe05a

Browse files
authored
Merge pull request #135 from S4mpl3r/feature
2 parents 98dec36 + 819cbcd commit 2abe05a

File tree

3 files changed

+88
-28
lines changed

3 files changed

+88
-28
lines changed

examples/groq/smart_scraper_groq_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
},
2626
"embeddings": {
2727
"api_key": openai_key,
28-
"model": "gpt-3.5-turbo",
28+
"model": "openai",
2929
},
3030
"headless": False
3131
}

scrapegraphai/graphs/abstract_graph.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
from abc import ABC, abstractmethod
66
from typing import Optional
77

8-
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
8+
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
9+
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
10+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
11+
912
from ..helpers import models_tokens
13+
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
1014

1115

1216
class AbstractGraph(ABC):
@@ -43,7 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4347
self.source = source
4448
self.config = config
4549
self.llm_model = self._create_llm(config["llm"], chat=True)
46-
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
50+
self.embedder_model = self._create_default_embedder(
51+
) if "embeddings" not in config else self._create_embedder(
4752
config["embeddings"])
4853

4954
# Set common configuration parameters
@@ -165,6 +170,85 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
165170
else:
166171
raise ValueError(
167172
"Model provided by the configuration not supported")
173+
174+
def _create_default_embedder(self) -> object:
    """
    Create an embedding model instance based on the chosen llm model.

    Used when the graph config has no explicit "embeddings" section: the
    embedder mirrors the provider of ``self.llm_model`` so embeddings come
    from the same backend as the chat model.

    Returns:
        object: An instance of the embedding model client.

    Raises:
        ValueError: If the llm model's provider has no matching embedder.
    """
    if isinstance(self.llm_model, OpenAI):
        return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
    elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
        # Already an embeddings client; reuse it as-is.
        return self.llm_model
    elif isinstance(self.llm_model, AzureOpenAI):
        return AzureOpenAIEmbeddings()
    elif isinstance(self.llm_model, Ollama):
        # Copy the model kwargs (a dict) so the pops below do not mutate
        # the llm model's own configuration in place.
        params = dict(self.llm_model._lc_kwargs)
        # streaming/temperature are chat-only options that the
        # embeddings client does not accept.
        params.pop("streaming", None)
        params.pop("temperature", None)

        return OllamaEmbeddings(**params)
    elif isinstance(self.llm_model, HuggingFace):
        return HuggingFaceHubEmbeddings(model=self.llm_model.model)
    elif isinstance(self.llm_model, Bedrock):
        return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
    else:
        raise ValueError("Embedding Model missing or not supported")
205+
206+
def _create_embedder(self, embedder_config: dict) -> object:
    """
    Create an embedding model instance based on the configuration provided.

    Args:
        embedder_config (dict): Configuration parameters for the embedding
            model. Must contain a "model" key; provider-specific keys such
            as "api_key" are forwarded to the underlying client.

    Returns:
        object: An instance of the embedding model client.

    Raises:
        KeyError: If the model is not supported.
        ValueError: If the provider in "model" is not recognised.
    """
    # Work on a shallow copy: the provider prefix is stripped from
    # "model" below, and the caller's config dict must not be mutated
    # as a side effect of building the client.
    embedder_config = dict(embedder_config)

    # Instantiate the embedding model based on the model name
    if "openai" in embedder_config["model"]:
        return OpenAIEmbeddings(api_key=embedder_config["api_key"])

    elif "azure" in embedder_config["model"]:
        return AzureOpenAIEmbeddings()

    elif "ollama" in embedder_config["model"]:
        # "ollama/<name>" -> "<name>"
        embedder_config["model"] = embedder_config["model"].split("/")[-1]
        if embedder_config["model"] not in models_tokens["ollama"]:
            raise KeyError("Model not supported")
        return OllamaEmbeddings(**embedder_config)

    elif "hugging_face" in embedder_config["model"]:
        if embedder_config["model"] not in models_tokens["hugging_face"]:
            raise KeyError("Model not supported")
        return HuggingFaceHubEmbeddings(model=embedder_config["model"])

    elif "bedrock" in embedder_config["model"]:
        # "bedrock/<model_id>" -> "<model_id>"
        embedder_config["model"] = embedder_config["model"].split("/")[-1]
        if embedder_config["model"] not in models_tokens["bedrock"]:
            raise KeyError("Model not supported")
        return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
    else:
        raise ValueError(
            "Model provided by the configuration not supported")
168252

169253
def get_state(self, key=None) -> dict:
170254
"""""

scrapegraphai/nodes/rag_node.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -87,31 +87,7 @@ def execute(self, state: dict) -> dict:
8787
if self.verbose:
8888
print("--- (updated chunks metadata) ---")
8989

90-
# check if embedder_model is provided, if not use llm_model
91-
embedding_model = self.embedder_model if self.embedder_model else self.llm_model
92-
93-
if isinstance(embedding_model, OpenAI):
94-
embeddings = OpenAIEmbeddings(
95-
api_key=embedding_model.openai_api_key)
96-
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
97-
embeddings = embedding_model
98-
elif isinstance(embedding_model, AzureOpenAI):
99-
embeddings = AzureOpenAIEmbeddings()
100-
elif isinstance(embedding_model, Ollama):
101-
# unwrap the kwargs from the model whihc is a dict
102-
params = embedding_model._lc_kwargs
103-
# remove streaming and temperature
104-
params.pop("streaming", None)
105-
params.pop("temperature", None)
106-
107-
embeddings = OllamaEmbeddings(**params)
108-
elif isinstance(embedding_model, HuggingFace):
109-
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
110-
elif isinstance(embedding_model, Bedrock):
111-
embeddings = BedrockEmbeddings(
112-
client=None, model_id=embedding_model.model_id)
113-
else:
114-
raise ValueError("Embedding Model missing or not supported")
90+
embeddings = self.embedder_model
11591

11692
retriever = FAISS.from_documents(
11793
chunked_docs, embeddings).as_retriever()

0 commit comments

Comments
 (0)