
Commit 7599234

feat: Enable end users to pass model instances of HuggingFaceHub
1 parent 98dec36 commit 7599234

File tree

4 files changed: +77 -0 lines changed

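The commit lets callers hand pre-built LangChain model objects to a graph through the "model_instance" config keys, rather than listing provider credentials in the config. A condensed sketch of that usage, distilled from the new example file in the first diff below (the token values are placeholders):

from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

# Pre-built LangChain model objects go straight into the config under "model_instance".
llm_model_instance = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    max_length=128,
    temperature=0.5,
    token="hf_xxx",  # placeholder HuggingFace Hub token
)
embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
    api_key="hf_xxx",  # placeholder HuggingFace Hub token
    model_name="sentence-transformers/all-MiniLM-l6-v2",
)

graph_config = {
    "llm": {"model_instance": llm_model_instance},
    "embeddings": {"model_instance": embedder_model_instance},
}
# graph_config is then handed to SmartScraperGraph(..., config=graph_config),
# exactly as in the new example file below.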
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
"""
Basic example of a scraping pipeline using SmartScraper with HuggingFace Hub
model instances passed in through the graph config
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings


## required environment variable in .env
# HUGGINGFACEHUB_API_TOKEN
load_dotenv()

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# ************************************************
# Initialize the model instances
# ************************************************

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

llm_model_instance = HuggingFaceEndpoint(
    repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)

embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
    api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
    "llm": {"model_instance": llm_model_instance},
    "embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
    # also accepts a string with the already downloaded HTML code
    source="https://www.hmhco.com/event",
    config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

scrapegraphai/graphs/abstract_graph.py

Lines changed: 7 additions & 0 deletions
@@ -64,6 +64,13 @@ def _set_model_token(self, llm):
                self.model_token = models_tokens["azure"][llm.model_name]
            except KeyError:
                raise KeyError("Model not supported")
+
+        elif 'HuggingFaceEndpoint' in str(type(llm)):
+            if 'mistral' in llm.repo_id:
+                try:
+                    self.model_token = models_tokens['mistral'][llm.repo_id]
+                except KeyError:
+                    raise KeyError("Model not supported")


    def _create_llm(self, llm_config: dict, chat=False) -> object:

scrapegraphai/helpers/models_tokens.py

Lines changed: 3 additions & 0 deletions
@@ -65,5 +65,8 @@
        "mistral.mistral-large-2402-v1:0": 32768,
        "cohere.embed-english-v3": 512,
        "cohere.embed-multilingual-v3": 512
+    },
+    "mistral": {
+        "mistralai/Mistral-7B-Instruct-v0.2": 32000
    }
}
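Given the _set_model_token branch above, only HuggingFaceEndpoint models whose repo_id contains "mistral" and that appear in this new "mistral" table resolve to a context size; anything else raises KeyError("Model not supported"). A user-side sketch for registering another mistral-family endpoint before building a graph; the import path and the Mixtral entry are assumptions, not part of this commit:

# Hypothetical: make an additional mistral-family repo id resolvable by _set_model_token.
# Assumes the models_tokens dict is importable from scrapegraphai.helpers; adjust the
# import if it is only exposed via scrapegraphai.helpers.models_tokens.
from scrapegraphai.helpers import models_tokens

models_tokens["mistral"]["mistralai/Mixtral-8x7B-Instruct-v0.1"] = 32000  # hypothetical entry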

scrapegraphai/nodes/rag_node.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
+from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings

from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace, Bedrock
from .base_node import BaseNode
@@ -95,6 +96,9 @@ def execute(self, state: dict) -> dict:
                api_key=embedding_model.openai_api_key)
        elif isinstance(embedding_model, AzureOpenAIEmbeddings):
            embeddings = embedding_model
+        elif isinstance(embedding_model, HuggingFaceInferenceAPIEmbeddings):
+            embeddings = embedding_model
+
        elif isinstance(embedding_model, AzureOpenAI):
            embeddings = AzureOpenAIEmbeddings()
        elif isinstance(embedding_model, Ollama):