
Commit bed3eed

feat(multiple_search): working multiple example

1 parent: 05e511e

File tree

6 files changed (+53 / -39 lines)


examples/openai/multiple_search_openai.py

Lines changed: 38 additions & 32 deletions
@@ -10,6 +10,36 @@
 load_dotenv()
 
 
+schema= """{
+    "Job Postings": {
+        "Company x": [
+            {
+                "title": "...",
+                "description": "...",
+                "location": "...",
+                "date_posted": "..",
+                "requirements": ["...", "...", "..."]
+            },
+            {
+                "title": "...",
+                "description": "...",
+                "location": "...",
+                "date_posted": "..",
+                "requirements": ["...", "...", "..."]
+            }
+        ],
+        "Company y": [
+            {
+                "title": "...",
+                "description": "...",
+                "location": "...",
+                "date_posted": "..",
+                "requirements": ["...", "...", "..."]
+            }
+        ]
+    }
+}"""
+
 # ************************************************
 # Define the configuration for the graph
 # ************************************************

@@ -19,47 +49,23 @@
 graph_config = {
     "llm": {
         "api_key": openai_key,
-        "model": "gpt-4o",
+        "model": "gpt-3.5-turbo",
     },
     "verbose": True,
     "headless": False,
+    "schema": schema,
 }
 
-schema= """{
-    "Job Postings": {
-        "Company A": [
-            {
-                "title": "Software Engineer",
-                "description": "Develop and maintain software applications.",
-                "location": "New York, NY",
-                "date_posted": "2024-05-01",
-                "requirements": ["Python", "Django", "REST APIs"]
-            },
-            {
-                "title": "Data Scientist",
-                "description": "Analyze and interpret complex data.",
-                "location": "San Francisco, CA",
-                "date_posted": "2024-05-05",
-                "requirements": ["Python", "Machine Learning", "SQL"]
-            }
-        ],
-        "Company B": [
-            {
-                "title": "Project Manager",
-                "description": "Manage software development projects.",
-                "location": "Boston, MA",
-                "date_posted": "2024-04-20",
-                "requirements": ["Project Management", "Agile", "Scrum"]
-            }
-        ]
-    }
-}"""
+
 
 multiple_search_graph = MultipleSearchGraph(
     prompt="List me all the projects with their description",
-    source= ["https://perinim.github.io/projects/", "https://perinim.github.io/projects/"],
+    source= [
+        "https://www.linkedin.com/jobs/machine-learning-engineer-offerte-di-lavoro/?currentJobId=3889037104&originalSubdomain=it",
+        "https://www.glassdoor.com/Job/italy-machine-learning-engineer-jobs-SRCH_IL.0,5_IN120_KO6,31.html",
+        "https://it.indeed.com/jobs?q=ML+engineer&vjk=3c2e6d27601ffaaa"
+    ],
     config=graph_config,
-    schema = schema
 )
 
 result = multiple_search_graph.run()
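
In short, after this commit the output schema is supplied through graph_config under the "schema" key instead of being passed to the graph constructor. A condensed sketch of the new usage follows; the import path and the OPENAI_APIKEY environment variable name mirror the repository's other OpenAI examples and are assumptions here, and the schema string is shortened to a placeholder.

import os

from dotenv import load_dotenv
from scrapegraphai.graphs import MultipleSearchGraph  # import path assumed from the other examples

load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")  # env var name assumed from the other examples

# Shortened placeholder schema the LLM output should follow (see the full skeleton above)
schema = """{
    "Job Postings": {
        "Company x": [
            {"title": "...", "description": "...", "location": "...", "date_posted": "..", "requirements": ["..."]}
        ]
    }
}"""

graph_config = {
    "llm": {"api_key": openai_key, "model": "gpt-3.5-turbo"},
    "verbose": True,
    "headless": False,
    "schema": schema,  # new in this commit: the schema rides inside the config dict
}

# The old schema= keyword argument is gone; only prompt, source and config remain.
multiple_search_graph = MultipleSearchGraph(
    prompt="List me all the projects with their description",
    source=["https://it.indeed.com/jobs?q=ML+engineer&vjk=3c2e6d27601ffaaa"],
    config=graph_config,
)
result = multiple_search_graph.run()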

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 2 deletions
@@ -40,12 +40,11 @@ class AbstractGraph(ABC):
     >>> result = my_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[dict]=None):
+    def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
 
         self.prompt = prompt
         self.source = source
         self.config = config
-        self.schema = schema
         self.llm_model = self._create_llm(config["llm"], chat=True)
         self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
                                                             ) if "embeddings" not in config else self._create_embedder(

@@ -62,13 +61,15 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None, sche
         self.headless = True if config is None else config.get(
             "headless", True)
         self.loader_kwargs = config.get("loader_kwargs", {})
+        self.schema = config.get("schema", None)
 
         common_params = {"headless": self.headless,
                          "verbose": self.verbose,
                          "loader_kwargs": self.loader_kwargs,
                          "llm_model": self.llm_model,
                          "embedder_model": self.embedder_model,
                          "schema": self.schema}
+
         self.set_common_params(common_params, overwrite=False)
 
     def set_common_params(self, params: dict, overwrite=False):
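
Net effect of these two hunks: AbstractGraph no longer takes schema as a constructor parameter; it reads it from the user config and hands it to every node through common_params. A minimal, self-contained sketch of that pattern, using an illustrative TinyGraph stand-in rather than the library's class:

from typing import Optional


class TinyGraph:
    """Illustrative stand-in for AbstractGraph's new schema handling."""

    def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.prompt = prompt
        self.source = source
        self.config = config
        # Schema is no longer a constructor argument; it is pulled out of the config.
        self.schema = config.get("schema", None)
        # It is then shared with every node, alongside the other common parameters.
        self.common_params = {"schema": self.schema}


graph = TinyGraph("list jobs", {"llm": {"model": "gpt-3.5-turbo"}, "schema": '{"Job Postings": {}}'})
print(graph.common_params["schema"])  # -> {"Job Postings": {}}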

scrapegraphai/graphs/multiple_search_graph.py

Lines changed: 6 additions & 3 deletions
@@ -14,6 +14,8 @@
 from .smart_scraper_graph import SmartScraperGraph
 
 from typing import List, Optional
+
+
 class MultipleSearchGraph(AbstractGraph):
     """
     MultipleSearchGraph is a scraping pipeline that searches the internet for answers to a given prompt.

@@ -39,7 +41,7 @@ class MultipleSearchGraph(AbstractGraph):
     >>> result = search_graph.run()
     """
 
-    def __init__(self, prompt: str, source: List[str], config: dict, schema:Optional[dict]= None):
+    def __init__(self, prompt: str, source: List[str], config: dict):
 
         self.max_results = config.get("max_results", 3)
 

@@ -48,7 +50,7 @@ def __init__(self, prompt: str, source: List[str], config: dict, schema:Optional
         else:
             self.copy_config = deepcopy(config)
 
-        super().__init__(prompt, config)
+        super().__init__(prompt, config, source)
 
     def _create_graph(self) -> BaseGraph:
         """

@@ -65,7 +67,7 @@ def _create_graph(self) -> BaseGraph:
         smart_scraper_instance = SmartScraperGraph(
             prompt="",
             source="",
-            config=self.copy_config
+            config=self.copy_config,
         )
 
         # ************************************************

@@ -85,6 +87,7 @@ def _create_graph(self) -> BaseGraph:
             output=["answer"],
             node_config={
                 "llm_model": self.llm_model,
+                "schema": self.config.get("schema", None),
             }
         )
 

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -81,7 +81,8 @@ def _create_graph(self) -> BaseGraph:
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
-                "llm_model": self.llm_model
+                "llm_model": self.llm_model,
+                "schema": self.config.get("schema", None),
             }
         )
 

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class GenerateAnswerNode(BaseNode):
 
     def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
                  node_name: str = "GenerateAnswer"):
-        print(node_config)
+
         super().__init__(node_name, "node", input, output, 2, node_config)
 
         self.llm_model = node_config["llm_model"]

scrapegraphai/nodes/merge_answers_node.py

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,8 @@ def execute(self, state: dict) -> dict:
         You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
         The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
         OUTPUT INSTRUCTIONS: {format_instructions}\n
+        You must format the output with the following schema, if not None:\n
+        SCHEMA: {schema}\n
         USER PROMPT: {user_prompt}\n
         WEBSITE CONTENT: {website_content}
         """

@@ -89,6 +91,7 @@ def execute(self, state: dict) -> dict:
             partial_variables={
                 "format_instructions": format_instructions,
                 "website_content": answers_str,
+                "schema": self.node_config.get("schema", None),
             },
         )
 
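
These hunks thread the schema into the merge prompt as a partial variable. Below is a minimal sketch of the resulting prompt construction, assuming LangChain's PromptTemplate (which the partial_variables usage suggests); the placeholder values stand in for what the node pulls from its state and from node_config.

from langchain.prompts import PromptTemplate

template_merge = """
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
OUTPUT INSTRUCTIONS: {format_instructions}\n
You must format the output with the following schema, if not None:\n
SCHEMA: {schema}\n
USER PROMPT: {user_prompt}\n
WEBSITE CONTENT: {website_content}
"""

prompt_template = PromptTemplate(
    template=template_merge,
    input_variables=["user_prompt"],
    partial_variables={
        "format_instructions": "Return a single valid JSON object.",  # placeholder
        "website_content": '[{"answer": "..."}, {"answer": "..."}]',  # placeholder for answers_str
        "schema": None,  # in the node this comes from self.node_config.get("schema", None)
    },
)

print(prompt_template.format(user_prompt="List me all the ML engineer job postings"))

Passing the schema as a partial variable keeps the call site unchanged: when node_config carries no schema, the prompt simply renders "SCHEMA: None" and the format instructions alone drive the output shape.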
