6
6
import warnings
7
7
from langchain_community .callbacks import get_openai_callback
8
8
from typing import Tuple
9
+ from collections import deque
9
10
10
11
11
12
class BaseGraph :
@@ -26,6 +27,8 @@ class BaseGraph:
26
27
27
28
Raises:
28
29
Warning: If the entry point node is not the first node in the list.
30
+ ValueError: If conditional_node does not have exactly two outgoing edges
31
+
29
32
30
33
Example:
31
34
>>> BaseGraph(
@@ -48,7 +51,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):
48
51
49
52
self .nodes = nodes
50
53
self .edges = self ._create_edges ({e for e in edges })
51
- self .entry_point = entry_point . node_name
54
+ self .entry_point = entry_point
52
55
53
56
if nodes [0 ].node_name != entry_point .node_name :
54
57
# raise a warning if the entry point is not the first node in the list
@@ -68,13 +71,16 @@ def _create_edges(self, edges: list) -> dict:
68
71
69
72
edge_dict = {}
70
73
for from_node , to_node in edges :
71
- edge_dict [from_node .node_name ] = to_node .node_name
74
+ if from_node in edge_dict :
75
+ edge_dict [from_node ].append (to_node )
76
+ else :
77
+ edge_dict [from_node ] = [to_node ]
72
78
return edge_dict
73
79
74
80
def execute (self , initial_state : dict ) -> Tuple [dict , list ]:
75
81
"""
76
- Executes the graph by traversing nodes starting from the entry point. The execution
77
- follows the edges based on the result of each node's execution and continues until
82
+ Executes the graph by traversing nodes in breadth-first order starting from the entry point.
83
+ The execution follows the edges based on the result of each node's execution and continues until
78
84
it reaches a node with no outgoing edges.
79
85
80
86
Args:
@@ -84,7 +90,6 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
84
90
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
85
91
"""
86
92
87
- current_node_name = self .nodes [0 ]
88
93
state = initial_state
89
94
90
95
# variables for tracking execution info
@@ -98,23 +103,22 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
98
103
"total_cost_USD" : 0.0 ,
99
104
}
100
105
101
- for index in self .nodes :
102
-
106
+ queue = deque ([self .entry_point ])
107
+ while queue :
108
+ current_node = queue .popleft ()
103
109
curr_time = time .time ()
104
- current_node = index
105
-
106
- with get_openai_callback () as cb :
110
+ with get_openai_callback () as callback :
107
111
result = current_node .execute (state )
108
112
node_exec_time = time .time () - curr_time
109
113
total_exec_time += node_exec_time
110
114
111
115
cb = {
112
- "node_name" : index .node_name ,
113
- "total_tokens" : cb .total_tokens ,
114
- "prompt_tokens" : cb .prompt_tokens ,
115
- "completion_tokens" : cb .completion_tokens ,
116
- "successful_requests" : cb .successful_requests ,
117
- "total_cost_USD" : cb .total_cost ,
116
+ "node_name" : current_node .node_name ,
117
+ "total_tokens" : callback .total_tokens ,
118
+ "prompt_tokens" : callback .prompt_tokens ,
119
+ "completion_tokens" : callback .completion_tokens ,
120
+ "successful_requests" : callback .successful_requests ,
121
+ "total_cost_USD" : callback .total_cost ,
118
122
"exec_time" : node_exec_time ,
119
123
}
120
124
@@ -128,21 +132,30 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
128
132
cb_total ["successful_requests" ] += cb ["successful_requests" ]
129
133
cb_total ["total_cost_USD" ] += cb ["total_cost_USD" ]
130
134
131
- if current_node .node_type == "conditional_node" :
132
- current_node_name = result
133
- elif current_node_name in self .edges :
134
- current_node_name = self .edges [current_node_name ]
135
- else :
136
- current_node_name = None
137
-
138
- exec_info .append ({
139
- "node_name" : "TOTAL RESULT" ,
140
- "total_tokens" : cb_total ["total_tokens" ],
141
- "prompt_tokens" : cb_total ["prompt_tokens" ],
142
- "completion_tokens" : cb_total ["completion_tokens" ],
143
- "successful_requests" : cb_total ["successful_requests" ],
144
- "total_cost_USD" : cb_total ["total_cost_USD" ],
145
- "exec_time" : total_exec_time ,
146
- })
135
+ if current_node in self .edges :
136
+ current_node_connections = self .edges [current_node ]
137
+ if current_node .node_type == 'conditional_node' :
138
+ # Assert that there are exactly two out edges from the conditional node
139
+ if len (current_node_connections ) != 2 :
140
+ raise ValueError (f"Conditional node should have exactly two out connections { current_node_connections .node_name } " )
141
+ if result ["next_node" ] == 0 :
142
+ queue .append (current_node_connections [0 ])
143
+ else :
144
+ queue .append (current_node_connections [1 ])
145
+ # remove the conditional node result
146
+ del result ["next_node" ]
147
+ else :
148
+ queue .extend (node for node in current_node_connections )
149
+
150
+
151
+ exec_info .append ({
152
+ "node_name" : "TOTAL RESULT" ,
153
+ "total_tokens" : cb_total ["total_tokens" ],
154
+ "prompt_tokens" : cb_total ["prompt_tokens" ],
155
+ "completion_tokens" : cb_total ["completion_tokens" ],
156
+ "successful_requests" : cb_total ["successful_requests" ],
157
+ "total_cost_USD" : cb_total ["total_cost_USD" ],
158
+ "exec_time" : total_exec_time ,
159
+ })
147
160
148
161
return state , exec_info
0 commit comments