
Commit b86aac2

feat: Allow end users to pass model instances for llm and embedding model
1 parent 40b2a34 commit b86aac2
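
The change lets the `llm` and `embeddings` sections of the graph config carry a ready-made LangChain model under a `model_instance` key instead of a model name. A minimal sketch of the new usage, condensed from the example file added below (deployment names are placeholders; Azure endpoint, key, and API version are assumed to come from the usual environment variables):

```python
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from scrapegraphai.graphs import SmartScraperGraph

# Placeholder deployments; endpoint, key and API version are read from
# AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY and OPENAI_API_VERSION.
llm = AzureChatOpenAI(azure_deployment="my-chat-deployment")
embedder = AzureOpenAIEmbeddings(azure_deployment="my-embeddings-deployment")

graph_config = {
    "llm": {"model_instance": llm},
    "embeddings": {"model_instance": embedder},
}

graph = SmartScraperGraph(
    prompt="Summarize this page",
    source="https://example.com",
    config=graph_config,
)
print(graph.run())
```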

File tree

4 files changed: +88 −3 lines changed

Lines changed: 63 additions & 0 deletions (new file)

```python
"""
Basic example of a scraping pipeline using SmartScraper with an Azure OpenAI key
"""

import os
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info

# Required environment variables in .env
# AZURE_OPENAI_ENDPOINT
# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
# MODEL_NAME
# AZURE_OPENAI_API_KEY
# OPENAI_API_TYPE
# AZURE_OPENAI_API_VERSION
# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
load_dotenv()

# ************************************************
# Initialize the model instances
# ************************************************

llm_model_instance = AzureChatOpenAI(
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]
)

embedder_model_instance = AzureOpenAIEmbeddings(
    azure_deployment=os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
    "llm": {"model_instance": llm_model_instance},
    "embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
    prompt="""List me all the events, with the following fields: company_name, event_name,
    event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode,
    event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending,
    refreshments_type, registration_available, registration_link""",
    # also accepts a string with the already downloaded HTML code
    source="https://www.hmhco.com/event",
    config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))
```

scrapegraphai/graphs/abstract_graph.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -19,7 +19,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.prompt = prompt
         self.source = source
         self.config = config
-        self.llm_model = self._create_llm(config["llm"])
+        self.llm_model = self._create_llm(config["llm"], chat=True)
         self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
             config["embeddings"])
 
@@ -32,7 +32,16 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.final_state = None
         self.execution_info = None
 
-    def _create_llm(self, llm_config: dict):
+    def _set_model_token(self, llm):
+
+        if 'Azure' in str(type(llm)):
+            try:
+                self.model_token = models_tokens["azure"][llm.model_name]
+            except KeyError:
+                raise KeyError("Model not supported")
+
+
+    def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
         Creates an instance of the language model (OpenAI or Gemini) based on configuration.
         """
@@ -42,6 +51,12 @@ def _create_llm(self, llm_config: dict):
         }
         llm_params = {**llm_defaults, **llm_config}
 
+        # If model instance is passed directly instead of the model details
+        if 'model_instance' in llm_params:
+            if chat:
+                self._set_model_token(llm_params['model_instance'])
+            return llm_params['model_instance']
+
         # Instantiate the language model based on the model name
         if "gpt-" in llm_params["model"]:
             try:
@@ -129,3 +144,4 @@ def run(self) -> str:
         Abstract method to execute the graph and return the result.
         """
         pass
+
```
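
To make the control flow explicit, here is a minimal standalone sketch of what the new branch does; `create_llm`, the placeholder defaults, and the inline `models_tokens` table are illustrative stand-ins for the library internals, not the actual implementation:

```python
# Illustrative table standing in for scrapegraphai.helpers.models_tokens.
models_tokens = {"azure": {"gpt-3.5-turbo": 4096}}


def create_llm(llm_config: dict, chat: bool = False):
    llm_defaults = {"temperature": 0}  # placeholder defaults
    llm_params = {**llm_defaults, **llm_config}

    # A ready-made model instance bypasses name-based construction entirely.
    if "model_instance" in llm_params:
        instance = llm_params["model_instance"]
        if chat and "Azure" in str(type(instance)):
            try:
                # Record the context window of the wrapped Azure chat model.
                token_limit = models_tokens["azure"][instance.model_name]
            except KeyError:
                raise KeyError("Model not supported")
            print(f"context window: {token_limit} tokens")
        return instance

    # ...otherwise the model would be built from llm_params["model"] as before...
    raise NotImplementedError("name-based construction omitted from this sketch")
```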

scrapegraphai/helpers/models_tokens.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -18,7 +18,9 @@
         "gpt-4-32k": 32768,
         "gpt-4-32k-0613": 32768,
     },
-
+    "azure": {
+        "gpt-3.5-turbo": 4096
+    },
     "gemini": {
         "gemini-pro": 128000,
     },
@@ -45,3 +47,4 @@
         "claude3": 200000
     }
 }
+
```
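
With this entry in place, the Azure context-window size can be read back directly from the table, the same way `abstract_graph.py` looks it up; a quick illustrative check:

```python
from scrapegraphai.helpers.models_tokens import models_tokens

# The new "azure" section maps model names to context-window sizes in tokens.
assert models_tokens["azure"]["gpt-3.5-turbo"] == 4096
```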

scrapegraphai/nodes/rag_node.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -92,6 +92,8 @@ def execute(self, state):
         if isinstance(embedding_model, OpenAI):
             embeddings = OpenAIEmbeddings(
                 api_key=embedding_model.openai_api_key)
+        elif isinstance(embedding_model, AzureOpenAIEmbeddings):
+            embeddings = embedding_model
         elif isinstance(embedding_model, AzureOpenAI):
             embeddings = AzureOpenAIEmbeddings()
         elif isinstance(embedding_model, Ollama):
@@ -133,3 +135,4 @@ def execute(self, state):
 
         state.update({self.output[0]: compressed_docs})
         return state
+
```
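
The effect of the new branch is that an `AzureOpenAIEmbeddings` instance supplied through the config is reused as-is rather than rebuilt from environment variables. A minimal sketch of that case; `pick_embeddings` is an illustrative helper, not the node's real structure:

```python
from langchain_openai import AzureOpenAIEmbeddings


def pick_embeddings(embedding_model):
    """Illustrative helper mirroring the dispatch above for the Azure case."""
    if isinstance(embedding_model, AzureOpenAIEmbeddings):
        # A pre-configured instance (e.g. graph_config["embeddings"]["model_instance"])
        # keeps its deployment, endpoint and API version untouched.
        return embedding_model
    # ...other providers would fall back to building a fresh embeddings client...
    raise NotImplementedError("only the Azure instance case is shown in this sketch")
```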
