Skip to content

Enable end users to pass model instances of HuggingFaceHub #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions examples/huggingfacehub/smart_scraper_huggingfacehub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings




## required environment variable in .env
#HUGGINGFACEHUB_API_TOKEN
load_dotenv()

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
# ************************************************
# Initialize the model instances
# ************************************************

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

llm_model_instance = HuggingFaceEndpoint(
repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)




embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
"llm": {"model_instance": llm_model_instance},
"embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
# also accepts a string with the already downloaded HTML code
source="https://www.hmhco.com/event",
config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))


11 changes: 10 additions & 1 deletion scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def _set_model_token(self, llm):
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")

elif 'HuggingFaceEndpoint' in str(type(llm)):
if 'mistral' in llm.repo_id:
try:
self.model_token = models_tokens['mistral'][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")


def _create_llm(self, llm_config: dict, chat=False) -> object:
Expand Down Expand Up @@ -181,7 +188,6 @@ def _create_default_embedder(self) -> object:
Raises:
ValueError: If the model is not supported.
"""

if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
Expand Down Expand Up @@ -216,6 +222,9 @@ def _create_embedder(self, embedder_config: dict) -> object:
Raises:
KeyError: If the model is not supported.
"""

if 'model_instance' in embedder_config:
return embedder_config['model_instance']

# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]:
Expand Down
3 changes: 3 additions & 0 deletions scrapegraphai/helpers/models_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,8 @@
"mistral.mistral-large-2402-v1:0": 32768,
"cohere.embed-english-v3": 512,
"cohere.embed-multilingual-v3": 512
},
"mistral": {
"mistralai/Mistral-7B-Instruct-v0.2": 32000
}
}
2 changes: 2 additions & 0 deletions scrapegraphai/nodes/rag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def execute(self, state: dict) -> dict:
if self.verbose:
print("--- (updated chunks metadata) ---")

# check if embedder_model is provided, if not use llm_model
self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model
embeddings = self.embedder_model

retriever = FAISS.from_documents(
Expand Down