Skip to content

Commit 627cbee

Browse files
committed
feat(parallel-execution): add asyncio event loop dispatcher with semaphore for parallel graph instances
TODO: still untested
1 parent 7ae50c0 commit 627cbee

File tree

1 file changed

+79
-24
lines changed

1 file changed

+79
-24
lines changed

scrapegraphai/nodes/graph_iterator_node.py

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22
GraphIterator Module
33
"""
44

5-
from typing import List, Optional
5+
import asyncio
66
import copy
7-
from tqdm import tqdm
7+
from typing import List, Optional
8+
9+
from tqdm.asyncio import tqdm
10+
811
from .base_node import BaseNode
912

1013

14+
_default_batchsize = 4
15+
16+
1117
class GraphIteratorNode(BaseNode):
    """
    A node responsible for instantiating and running multiple graph instances in
    parallel. It creates one deep copy of the configured graph instance per URL in
    the input, assigns each copy its own source, and runs them concurrently on an
    asyncio event loop, bounded by a semaphore.

    Attributes:
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node; recognized keys are
            "verbose", "batchsize" (max concurrent graph runs, default `_default_batchsize`)
            and "graph_instance" (the graph to replicate per URL).
        node_name (str): The unique identifier name for the node, defaulting to "GraphIterator".
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GraphIterator",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)

        self.verbose = (
            False if node_config is None else node_config.get("verbose", False)
        )

    def execute(self, state: dict) -> dict:
        """
        Executes the node's logic to instantiate and run multiple graph instances
        in parallel.

        Args:
            state (dict): The current state of the graph.

        Returns:
            dict: The updated state with the output key containing the aggregated
            results of all parallel graph instances.

        Raises:
            KeyError: If the input keys are not found in the state, indicating that the
            necessary information for running the graph instances is missing.
        """
        batchsize = self.node_config.get("batchsize", _default_batchsize)

        if self.verbose:
            print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: drive the coroutine to
            # completion on a fresh loop. (asyncio.run creates and closes it.)
            return asyncio.run(self._async_execute(state, batchsize))

        # An event loop IS already running in this thread (e.g. a notebook or an
        # async framework). Both run_until_complete() and asyncio.run() would
        # raise "event loop is already running" here, so run the coroutine on a
        # private loop in a worker thread and block until it finishes.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                asyncio.run, self._async_execute(state, batchsize)
            )
            return future.result()

    async def _async_execute(self, state: dict, batchsize: int) -> dict:
        """asynchronously executes the node's logic with multiple graph instances
        running in parallel, using a semaphore of some size for concurrency regulation

        Args:
            state: The current state of the graph.
            batchsize: The maximum number of concurrent instances allowed.

        Returns:
            The updated state with the output key containing the results
            aggregated out of all parallel graph instances.

        Raises:
            KeyError: If the input keys are not found in the state.
            ValueError: If no "graph_instance" is present in the node config.
        """

        # interprets input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # fetches data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        urls = input_data[1]

        graph_instance = self.node_config.get("graph_instance", None)

        if graph_instance is None:
            raise ValueError("graph instance is required for concurrent execution")

        # sets the prompt for the graph instance
        graph_instance.prompt = user_prompt

        participants = []

        # semaphore to limit the number of concurrent tasks
        semaphore = asyncio.Semaphore(batchsize)

        async def _async_run(graph):
            # gate concurrency: at most `batchsize` graphs run at once;
            # graph.run() is synchronous, so push it to a worker thread
            async with semaphore:
                return await asyncio.to_thread(graph.run)

        # creates a deepcopy of the graph instance for each endpoint
        for url in urls:
            instance = copy.deepcopy(graph_instance)
            instance.source = url

            participants.append(instance)

        futures = [_async_run(graph) for graph in participants]

        # gather results with a progress bar; order matches `urls` order
        answers = await tqdm.gather(
            *futures, desc="processing graph instances", disable=not self.verbose
        )

        state.update({self.output[0]: answers})

        return state

0 commit comments

Comments
 (0)