
Commit a44a2e7

Merge pull request #148 from shorthills-ai/pre/beta
2 parents d277b34 + d05093a

4 files changed: +78 -1 lines changed
Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
"""
Basic example of a scraping pipeline using SmartScraper with Hugging Face models
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings


# Required environment variable in .env: HUGGINGFACEHUB_API_TOKEN
load_dotenv()

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# ************************************************
# Initialize the model instances
# ************************************************

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

llm_model_instance = HuggingFaceEndpoint(
    repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)

embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
    api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
    "llm": {"model_instance": llm_model_instance},
    "embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
    # also accepts a string with the already downloaded HTML code
    source="https://www.hmhco.com/event",
    config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))
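
The script above expects a .env file alongside it providing the Hugging Face token read by load_dotenv(); a minimal sketch of such a file, with a placeholder value rather than a real credential:

HUGGINGFACEHUB_API_TOKEN=hf_your_token_here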

scrapegraphai/graphs/abstract_graph.py

Lines changed: 10 additions & 1 deletion
@@ -69,6 +69,13 @@ def _set_model_token(self, llm):
                 self.model_token = models_tokens["azure"][llm.model_name]
             except KeyError:
                 raise KeyError("Model not supported")
+
+        elif 'HuggingFaceEndpoint' in str(type(llm)):
+            if 'mistral' in llm.repo_id:
+                try:
+                    self.model_token = models_tokens['mistral'][llm.repo_id]
+                except KeyError:
+                    raise KeyError("Model not supported")


     def _create_llm(self, llm_config: dict, chat=False) -> object:
@@ -181,7 +188,6 @@ def _create_default_embedder(self) -> object:
         Raises:
             ValueError: If the model is not supported.
         """
-
         if isinstance(self.llm_model, OpenAI):
             return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
         elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
@@ -216,6 +222,9 @@ def _create_embedder(self, embedder_config: dict) -> object:
         Raises:
             KeyError: If the model is not supported.
         """
+
+        if 'model_instance' in embedder_config:
+            return embedder_config['model_instance']

         # Instantiate the embedding model based on the model name
         if "openai" in embedder_config["model"]:

scrapegraphai/helpers/models_tokens.py

Lines changed: 3 additions & 0 deletions
@@ -65,5 +65,8 @@
         "mistral.mistral-large-2402-v1:0": 32768,
         "cohere.embed-english-v3": 512,
         "cohere.embed-multilingual-v3": 512
+    },
+    "mistral": {
+        "mistralai/Mistral-7B-Instruct-v0.2": 32000
     }
 }
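
Together with the HuggingFaceEndpoint branch added to _set_model_token above, this entry lets the context size be resolved from the endpoint's repo_id. A minimal standalone sketch of that lookup, with the dict trimmed to just the new entry:

models_tokens = {
    "mistral": {
        "mistralai/Mistral-7B-Instruct-v0.2": 32000,
    },
}

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
model_token = models_tokens["mistral"][repo_id]  # 32000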

scrapegraphai/nodes/rag_node.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ def execute(self, state: dict) -> dict:
         if self.verbose:
             print("--- (updated chunks metadata) ---")

+        # check if embedder_model is provided, if not use llm_model
+        self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model
         embeddings = self.embedder_model

         retriever = FAISS.from_documents(
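
The two added lines make a dedicated embedder optional at this point in execute. A minimal standalone sketch of the same fallback pattern, using stand-in objects rather than the real node attributes:

llm_model = object()    # stand-in for the configured LLM instance
embedder_model = None   # no dedicated embedder configured

# reuse the LLM instance as the embedder when none was provided
embedder_model = embedder_model if embedder_model else llm_model
assert embedder_model is llm_model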
