Changed the way embedding model creation is handled in the AbstractGraph class. #135

Merged: 1 commit, May 3, 2024.

2 changes: 1 addition & 1 deletion examples/groq/smart_scraper_groq_openai.py
@@ -25,7 +25,7 @@
     },
     "embeddings": {
         "api_key": openai_key,
-        "model": "gpt-3.5-turbo",
+        "model": "openai",
     },
     "headless": False
 }
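With this change the embeddings block in a graph config names a provider rather than a chat model, and AbstractGraph maps that name to a dedicated embeddings client. A minimal sketch of the updated config shape, assuming openai_key holds a real API key (the llm block is elided here):

    # Sketch of the new-style config; only the embeddings block changed.
    graph_config = {
        # ... "llm" block as in the example above ...
        "embeddings": {
            "api_key": openai_key,  # credentials for the embeddings provider
            "model": "openai",      # provider name, resolved by _create_embedder
        },
        "headless": False,
    }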
88 changes: 86 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
@@ -5,8 +5,12 @@
 from abc import ABC, abstractmethod
 from typing import Optional

-from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
+from langchain_aws.embeddings.bedrock import BedrockEmbeddings
+from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
 from ..helpers import models_tokens
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI


class AbstractGraph(ABC):
@@ -43,7 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.source = source
         self.config = config
         self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
+        self.embedder_model = self._create_default_embedder(
+        ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])

         # Set common configuration parameters
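Spelled out, the new initializer logic is equivalent to this sketch (not the literal source): use the user-supplied embeddings config when one is present, otherwise derive a default embedder from the already-built chat model.

    if "embeddings" in config:
        self.embedder_model = self._create_embedder(config["embeddings"])
    else:
        self.embedder_model = self._create_default_embedder()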
@@ -165,6 +170,85 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         else:
             raise ValueError(
                 "Model provided by the configuration not supported")

+    def _create_default_embedder(self) -> object:
+        """
+        Create an embedding model instance based on the chosen llm model.
+
+        Returns:
+            object: An instance of the embedding model client.
+
+        Raises:
+            ValueError: If the model is not supported.
+        """
+
+        if isinstance(self.llm_model, OpenAI):
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
+        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
+            return self.llm_model
+        elif isinstance(self.llm_model, AzureOpenAI):
+            return AzureOpenAIEmbeddings()
+        elif isinstance(self.llm_model, Ollama):
+            # unwrap the kwargs from the model, which is a dict
+            params = self.llm_model._lc_kwargs
+            # remove streaming and temperature
+            params.pop("streaming", None)
+            params.pop("temperature", None)
+
+            return OllamaEmbeddings(**params)
+        elif isinstance(self.llm_model, HuggingFace):
+            return HuggingFaceHubEmbeddings(model=self.llm_model.model)
+        elif isinstance(self.llm_model, Bedrock):
+            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
+        else:
+            raise ValueError("Embedding Model missing or not supported")
+
+    def _create_embedder(self, embedder_config: dict) -> object:
+        """
+        Create an embedding model instance based on the configuration provided.
+
+        Args:
+            embedder_config (dict): Configuration parameters for the embedding model.
+
+        Returns:
+            object: An instance of the embedding model client.
+
+        Raises:
+            KeyError: If the model is not supported.
+        """
+
+        # Instantiate the embedding model based on the model name
+        if "openai" in embedder_config["model"]:
+            return OpenAIEmbeddings(api_key=embedder_config["api_key"])
+
+        elif "azure" in embedder_config["model"]:
+            return AzureOpenAIEmbeddings()
+
+        elif "ollama" in embedder_config["model"]:
+            embedder_config["model"] = embedder_config["model"].split("/")[-1]
+            try:
+                models_tokens["ollama"][embedder_config["model"]]
+            except KeyError:
+                raise KeyError("Model not supported")
+            return OllamaEmbeddings(**embedder_config)
+
+        elif "hugging_face" in embedder_config["model"]:
+            try:
+                models_tokens["hugging_face"][embedder_config["model"]]
+            except KeyError:
+                raise KeyError("Model not supported")
+            return HuggingFaceHubEmbeddings(model=embedder_config["model"])
+
+        elif "bedrock" in embedder_config["model"]:
+            embedder_config["model"] = embedder_config["model"].split("/")[-1]
+            try:
+                models_tokens["bedrock"][embedder_config["model"]]
+            except KeyError:
+                raise KeyError("Model not supported")
+            return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
+        else:
+            raise ValueError(
+                "Model provided by the configuration not supported")
+
     def get_state(self, key=None) -> dict:
         """
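Taken together, the two helpers split embedder selection into a default path and an explicit path. _create_default_embedder inspects the type of the already-built chat model and returns a matching embeddings client (for Ollama it reuses the model's _lc_kwargs minus streaming and temperature), while _create_embedder keys off the provider name in the config and, for Ollama, Hugging Face, and Bedrock, validates the model name against models_tokens. A hedged sketch of the explicit path on a concrete graph instance; the nomic-embed-text name is an assumption and would need an entry in models_tokens["ollama"] to pass validation:

    # Hypothetical usage of _create_embedder (model name and URL are placeholders).
    embedder_config = {
        "model": "ollama/nomic-embed-text",    # provider/model; the prefix is stripped
        "base_url": "http://localhost:11434",  # forwarded to OllamaEmbeddings
    }
    embedder = graph._create_embedder(embedder_config)
    # equivalent to: OllamaEmbeddings(model="nomic-embed-text",
    #                                 base_url="http://localhost:11434")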
26 changes: 1 addition & 25 deletions scrapegraphai/nodes/rag_node.py
@@ -87,31 +87,7 @@ def execute(self, state: dict) -> dict:
         if self.verbose:
             print("--- (updated chunks metadata) ---")

-        # check if embedder_model is provided, if not use llm_model
-        embedding_model = self.embedder_model if self.embedder_model else self.llm_model
-
-        if isinstance(embedding_model, OpenAI):
-            embeddings = OpenAIEmbeddings(
-                api_key=embedding_model.openai_api_key)
-        elif isinstance(embedding_model, AzureOpenAIEmbeddings):
-            embeddings = embedding_model
-        elif isinstance(embedding_model, AzureOpenAI):
-            embeddings = AzureOpenAIEmbeddings()
-        elif isinstance(embedding_model, Ollama):
-            # unwrap the kwargs from the model whihc is a dict
-            params = embedding_model._lc_kwargs
-            # remove streaming and temperature
-            params.pop("streaming", None)
-            params.pop("temperature", None)
-
-            embeddings = OllamaEmbeddings(**params)
-        elif isinstance(embedding_model, HuggingFace):
-            embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
-        elif isinstance(embedding_model, Bedrock):
-            embeddings = BedrockEmbeddings(
-                client=None, model_id=embedding_model.model_id)
-        else:
-            raise ValueError("Embedding Model missing or not supported")
+        embeddings = self.embedder_model

         retriever = FAISS.from_documents(
             chunked_docs, embeddings).as_retriever()
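RAGNode no longer rebuilds an embeddings client from the chat model; it consumes the embedder_model already resolved by AbstractGraph and hands it straight to the retriever. A minimal sketch of the surviving path, assuming chunked_docs is a list of langchain Document objects and user_prompt is a placeholder query:

    # Sketch: the node now trusts the graph-level embedder.
    embeddings = self.embedder_model
    retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
    relevant_docs = retriever.invoke(user_prompt)  # retrieve chunks for the prompt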