 GraphIterator Module
 """
 
-from typing import List, Optional
+import asyncio
 import copy
-from tqdm import tqdm
+from typing import List, Optional
+
+from tqdm.asyncio import tqdm
+
 from .base_node import BaseNode
 
 
+_default_batchsize = 4
+
+
 class GraphIteratorNode(BaseNode):
     """
     A node responsible for instantiating and running multiple graph instances in parallel.
@@ -23,12 +29,20 @@ class GraphIteratorNode(BaseNode):
         node_name (str): The unique identifier name for the node, defaulting to "Parse".
     """
 
-    def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None, node_name: str = "GraphIterator"):
+    def __init__(
+        self,
+        input: str,
+        output: List[str],
+        node_config: Optional[dict] = None,
+        node_name: str = "GraphIterator",
+    ):
         super().__init__(node_name, "node", input, output, 2, node_config)
 
-        self.verbose = False if node_config is None else node_config.get("verbose", False)
+        self.verbose = (
+            False if node_config is None else node_config.get("verbose", False)
+        )
 
-    def execute(self, state: dict) -> dict:
+    def execute(self, state: dict) -> dict:
         """
         Executes the node's logic to instantiate and run multiple graph instances in parallel.
 
@@ -43,37 +57,78 @@ def execute(self, state: dict) -> dict:
            KeyError: If the input keys are not found in the state, indicating that the
                      necessary information for running the graph instances is missing.
        """
+        batchsize = self.node_config.get("batchsize", _default_batchsize)
 
        if self.verbose:
-            print(f"--- Executing {self.node_name} Node ---")
+            print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")
+
+        try:
+            eventloop = asyncio.get_event_loop()
+        except RuntimeError:
+            eventloop = None
+
+        if eventloop and eventloop.is_running():
+            state = eventloop.run_until_complete(self._async_execute(state, batchsize))
+        else:
+            state = asyncio.run(self._async_execute(state, batchsize))
+
+        return state
+
+    async def _async_execute(self, state: dict, batchsize: int) -> dict:
+        """asynchronously executes the node's logic with multiple graph instances
+        running in parallel, using a semaphore of some size for concurrency regulation
+
+        Args:
+            state: The current state of the graph.
+            batchsize: The maximum number of concurrent instances allowed.
+
+        Returns:
+            The updated state with the output key containing the results
+            aggregated out of all parallel graph instances.
 
-        # Interpret input keys based on the provided input expression
+        Raises:
+            KeyError: If the input keys are not found in the state.
+        """
+
+        # interprets input keys based on the provided input expression
        input_keys = self.get_input_keys(state)
 
-        # Fetching data from the state based on the input keys
+        # fetches data from the state based on the input keys
        input_data = [state[key] for key in input_keys]
 
        user_prompt = input_data[0]
        urls = input_data[1]
 
        graph_instance = self.node_config.get("graph_instance", None)
+
        if graph_instance is None:
-            raise ValueError("Graph instance is required for graph iteration.")
-
-        # set the prompt and source for each url
+            raise ValueError("graph instance is required for concurrent execution")
+
+        # sets the prompt for the graph instance
        graph_instance.prompt = user_prompt
-        graphs_instances = []
+
+        participants = []
+
+        # semaphore to limit the number of concurrent tasks
+        semaphore = asyncio.Semaphore(batchsize)
+
+        async def _async_run(graph):
+            async with semaphore:
+                return await asyncio.to_thread(graph.run)
+
+        # creates a deepcopy of the graph instance for each endpoint
        for url in urls:
-            # make a copy of the graph instance
-            copy_graph_instance = copy.copy(graph_instance)
-            copy_graph_instance.source = url
-            graphs_instances.append(copy_graph_instance)
-
-        # run the graph for each url and use tqdm for progress bar
-        graphs_answers = []
-        for graph in tqdm(graphs_instances, desc="Processing Graph Instances", disable=not self.verbose):
-            result = graph.run()
-            graphs_answers.append(result)
-
-        state.update({self.output[0]: graphs_answers})
+            instance = copy.deepcopy(graph_instance)
+            instance.source = url
+
+            participants.append(instance)
+
+        futures = [_async_run(graph) for graph in participants]
+
+        answers = await tqdm.gather(
+            *futures, desc="processing graph instances", disable=not self.verbose
+        )
+
+        state.update({self.output[0]: answers})
+
        return state
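
The core of the change is the bounded fan-out inside _async_execute: an asyncio.Semaphore(batchsize) caps how many blocking graph.run() calls may be in flight at once, and each call is pushed onto a worker thread with asyncio.to_thread so the event loop stays responsive. A minimal standalone sketch of that pattern follows; the fetch and run_all names are placeholders for illustration and not part of the module, and asyncio.to_thread requires Python 3.9+.

import asyncio
import time

def fetch(url: str) -> str:
    # stand-in for a blocking graph.run(); real work would happen here
    time.sleep(0.5)
    return f"result for {url}"

async def run_all(urls, batchsize: int = 4):
    # at most `batchsize` tasks may hold the semaphore at any moment
    semaphore = asyncio.Semaphore(batchsize)

    async def run_one(url):
        async with semaphore:
            # off-load the blocking call to a thread so the loop keeps scheduling others
            return await asyncio.to_thread(fetch, url)

    # one task per URL; gather returns results in the same order as the inputs
    return await asyncio.gather(*(run_one(u) for u in urls))

if __name__ == "__main__":
    print(asyncio.run(run_all([f"https://example.com/page/{i}" for i in range(10)])))

In the node itself, tqdm.asyncio.tqdm.gather takes the place of asyncio.gather over the same awaitables, adding a progress bar while keeping the same ordering of results.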