Commit 039ba2e

fix: Fixed pydantic error on SearchGraphs
Changed instantiation location of iterated graph classes
Parent: 88b2c46 · Commit: 039ba2e

File tree: 4 files changed (+45, -33 lines)

  scrapegraphai/graphs/omni_search_graph.py
  scrapegraphai/graphs/search_graph.py
  scrapegraphai/nodes/generate_answer_node.py
  scrapegraphai/nodes/graph_iterator_node.py

scrapegraphai/graphs/omni_search_graph.py (10 additions & 8 deletions)

@@ -61,12 +61,12 @@ def _create_graph(self) -> BaseGraph:
             BaseGraph: A graph instance representing the web scraping and searching workflow.
         """

-        omni_scraper_instance = OmniScraperGraph(
-            prompt="",
-            source="",
-            config=self.copy_config,
-            schema=self.copy_schema
-        )
+        # omni_scraper_instance = OmniScraperGraph(
+        #     prompt="",
+        #     source="",
+        #     config=self.copy_config,
+        #     schema=self.copy_schema
+        # )

         search_internet_node = SearchInternetNode(
             input="user_prompt",
@@ -81,8 +81,10 @@ def _create_graph(self) -> BaseGraph:
             input="user_prompt & urls",
             output=["results"],
             node_config={
-                "graph_instance": omni_scraper_instance,
-            }
+                "graph_instance": OmniScraperGraph,
+                "scraper_config": self.copy_config,
+            },
+            schema=self.copy_schema
         )

         merge_answers_node = MergeAnswersNode(
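The change above is the caller's half of the fix: OmniSearchGraph no longer pre-builds a single OmniScraperGraph to be shared across iterations, and instead hands GraphIteratorNode the class itself plus the config needed to build one instance per URL (see graph_iterator_node.py below). A minimal, self-contained sketch of that deferred-instantiation pattern, using toy names rather than the ScrapeGraphAI API:

# Toy sketch: pass the class and its construction arguments, then build one
# fresh object per work item instead of sharing or copying a single instance.
from dataclasses import dataclass, field

@dataclass
class ToyScraper:
    prompt: str = ""
    source: str = ""
    config: dict = field(default_factory=dict)

def build_per_url(graph_cls, scraper_config, urls, user_prompt):
    graphs = [graph_cls(config=dict(scraper_config)) for _ in urls]
    for url, graph in zip(urls, graphs):
        graph.prompt = user_prompt
        graph.source = url
    return graphs

graphs = build_per_url(ToyScraper, {"graph_depth": 0},
                       ["https://a.example", "https://b.example"], "find titles")
assert graphs[0].config is not graphs[1].config  # each graph owns its own config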

scrapegraphai/graphs/search_graph.py (10 additions & 8 deletions)

@@ -62,12 +62,12 @@ def _create_graph(self) -> BaseGraph:
             BaseGraph: A graph instance representing the web scraping and searching workflow.
         """

-        smart_scraper_instance = SmartScraperGraph(
-            prompt="",
-            source="",
-            config=self.copy_config,
-            schema=self.copy_schema
-        )
+        # smart_scraper_instance = SmartScraperGraph(
+        #     prompt="",
+        #     source="",
+        #     config=self.copy_config,
+        #     schema=self.copy_schema
+        # )

         search_internet_node = SearchInternetNode(
             input="user_prompt",
@@ -82,8 +82,10 @@ def _create_graph(self) -> BaseGraph:
             input="user_prompt & urls",
             output=["results"],
             node_config={
-                "graph_instance": smart_scraper_instance,
-            }
+                "graph_instance": SmartScraperGraph,
+                "scraper_config": self.copy_config
+            },
+            schema=self.copy_schema
         )

         merge_answers_node = MergeAnswersNode(
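search_graph.py receives the identical treatment, with SmartScraperGraph passed as a class alongside its config. The fix is transparent to callers, who construct and run SearchGraph exactly as before; a hedged usage sketch (the model name and config keys here are illustrative, not mandated by this commit):

from scrapegraphai.graphs import SearchGraph

graph_config = {
    "llm": {"model": "gpt-4o-mini", "api_key": "YOUR_API_KEY"},
}

search_graph = SearchGraph(
    prompt="List the top attractions in Rome",
    config=graph_config,
)
result = search_graph.run()
print(result)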

scrapegraphai/nodes/generate_answer_node.py (4 additions & 4 deletions)

@@ -89,14 +89,14 @@ def execute(self, state: dict) -> dict:
         doc = input_data[1]

         if self.node_config.get("schema", None) is not None:
-
+
             if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
                 self.llm_model = self.llm_model.with_structured_output(
-                    schema = self.node_config["schema"],
-                    method="function_calling") # json schema works only on specific models
+                    schema = self.node_config["schema"]) # json schema works only on specific models

             # default parser to empty lambda function
-            output_parser = lambda x: x
+            def output_parser(x):
+                return x
             if is_basemodel_subclass(self.node_config["schema"]):
                 output_parser = dict
             format_instructions = "NA"
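Two small cleanups here: dropping the explicit method="function_calling" argument falls back to with_structured_output's default method for the model in use, and the lambda becomes a named function (the usual fix for lint rule E731). A hedged sketch of what the structured-output path does with a Pydantic schema (ChatOpenAI shown; the model name is a placeholder):

# Sketch only: mirrors the node's call, outside the node for clarity.
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel

class Answer(BaseModel):
    title: str
    summary: str

llm = ChatOpenAI(model="gpt-4o-mini")
structured_llm = llm.with_structured_output(schema=Answer)
# Invoking structured_llm returns an Answer instance, which is why the node
# can set output_parser = dict when the schema is a BaseModel subclass.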

scrapegraphai/nodes/graph_iterator_node.py (21 additions & 13 deletions)

@@ -2,11 +2,10 @@
 GraphIterator Module
 """
 import asyncio
-import copy
 from typing import List, Optional
 from tqdm.asyncio import tqdm
-from ..utils.logging import get_logger
 from .base_node import BaseNode
+from langchain_core.pydantic_v1 import BaseModel


 DEFAULT_BATCHSIZE = 16
@@ -31,12 +30,14 @@ def __init__(
         output: List[str],
         node_config: Optional[dict] = None,
         node_name: str = "GraphIterator",
+        schema: Optional[BaseModel] = None,
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)

         self.verbose = (
             False if node_config is None else node_config.get("verbose", False)
         )
+        self.schema = schema

     def execute(self, state: dict) -> dict:
         """
@@ -97,16 +98,24 @@ async def _async_execute(self, state: dict, batchsize: int) -> dict:
         urls = input_data[1]

         graph_instance = self.node_config.get("graph_instance", None)
+        scraper_config = self.node_config.get("scraper_config", None)

         if graph_instance is None:
             raise ValueError("graph instance is required for concurrent execution")

-        if "graph_depth" in graph_instance.config:
-            graph_instance.config["graph_depth"] += 1
-        else:
-            graph_instance.config["graph_depth"] = 1
+        graph_instance = [graph_instance(
+            prompt="",
+            source="",
+            config=scraper_config,
+            schema=self.schema) for _ in range(len(urls))]
+
+        for graph in graph_instance:
+            if "graph_depth" in graph.config:
+                graph.config["graph_depth"] += 1
+            else:
+                graph.config["graph_depth"] = 1

-        graph_instance.prompt = user_prompt
+            graph.prompt = user_prompt

         participants = []

@@ -116,13 +125,12 @@ async def _async_run(graph):
             async with semaphore:
                 return await asyncio.to_thread(graph.run)

-        for url in urls:
-            instance = copy.copy(graph_instance)
-            instance.source = url
+        for url, graph in zip(urls, graph_instance):
+            graph.source = url
             if url.startswith("http"):
-                instance.input_key = "url"
-            participants.append(instance)
-
+                graph.input_key = "url"
+            participants.append(graph)
+
         futures = [_async_run(graph) for graph in participants]

         answers = await tqdm.gather(
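This file holds the core of the fix. Previously the node received one pre-built graph and shallow-copied it per URL; copy.copy shares nested mutable state (such as the config dict) between copies, which is a plausible source of both the depth-counting quirks and the reported Pydantic error. Building a fresh graph per URL from the class and scraper_config sidesteps the sharing entirely. A self-contained sketch of the shallow-copy pitfall the commit removes:

import copy

class Graph:
    def __init__(self, config):
        self.config = config

base = Graph({"graph_depth": 1})
clone = copy.copy(base)            # shallow copy: clone.config is base.config
clone.config["graph_depth"] += 1
print(base.config["graph_depth"])  # prints 2: mutating the clone mutated the original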
