6
6
import warnings
7
7
from langchain_community .callbacks import get_openai_callback
8
8
from typing import Tuple
9
- from collections import deque
10
9
11
10
12
11
class BaseGraph :
@@ -27,8 +26,6 @@ class BaseGraph:
27
26
28
27
Raises:
29
28
Warning: If the entry point node is not the first node in the list.
30
- ValueError: If conditional_node does not have exactly two outgoing edges
31
-
32
29
33
30
Example:
34
31
>>> BaseGraph(
@@ -51,7 +48,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):
51
48
52
49
self .nodes = nodes
53
50
self .edges = self ._create_edges ({e for e in edges })
54
- self .entry_point = entry_point
51
+ self .entry_point = entry_point . node_name
55
52
56
53
if nodes [0 ].node_name != entry_point .node_name :
57
54
# raise a warning if the entry point is not the first node in the list
@@ -71,16 +68,13 @@ def _create_edges(self, edges: list) -> dict:
71
68
72
69
edge_dict = {}
73
70
for from_node , to_node in edges :
74
- if from_node in edge_dict :
75
- edge_dict [from_node ].append (to_node )
76
- else :
77
- edge_dict [from_node ] = [to_node ]
71
+ edge_dict [from_node .node_name ] = to_node .node_name
78
72
return edge_dict
79
73
80
74
def execute (self , initial_state : dict ) -> Tuple [dict , list ]:
81
75
"""
82
- Executes the graph by traversing nodes in breadth-first order starting from the entry point.
83
- The execution follows the edges based on the result of each node's execution and continues until
76
+ Executes the graph by traversing nodes starting from the entry point. The execution
77
+ follows the edges based on the result of each node's execution and continues until
84
78
it reaches a node with no outgoing edges.
85
79
86
80
Args:
@@ -90,6 +84,7 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
90
84
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
91
85
"""
92
86
87
+ current_node_name = self .nodes [0 ]
93
88
state = initial_state
94
89
95
90
# variables for tracking execution info
@@ -103,22 +98,23 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
103
98
"total_cost_USD" : 0.0 ,
104
99
}
105
100
106
- queue = deque ([self .entry_point ])
107
- while queue :
108
- current_node = queue .popleft ()
101
+ for index in self .nodes :
102
+
109
103
curr_time = time .time ()
110
- with get_openai_callback () as callback :
104
+ current_node = index
105
+
106
+ with get_openai_callback () as cb :
111
107
result = current_node .execute (state )
112
108
node_exec_time = time .time () - curr_time
113
109
total_exec_time += node_exec_time
114
110
115
111
cb = {
116
- "node_name" : current_node .node_name ,
117
- "total_tokens" : callback .total_tokens ,
118
- "prompt_tokens" : callback .prompt_tokens ,
119
- "completion_tokens" : callback .completion_tokens ,
120
- "successful_requests" : callback .successful_requests ,
121
- "total_cost_USD" : callback .total_cost ,
112
+ "node_name" : index .node_name ,
113
+ "total_tokens" : cb .total_tokens ,
114
+ "prompt_tokens" : cb .prompt_tokens ,
115
+ "completion_tokens" : cb .completion_tokens ,
116
+ "successful_requests" : cb .successful_requests ,
117
+ "total_cost_USD" : cb .total_cost ,
122
118
"exec_time" : node_exec_time ,
123
119
}
124
120
@@ -132,31 +128,21 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
132
128
cb_total ["successful_requests" ] += cb ["successful_requests" ]
133
129
cb_total ["total_cost_USD" ] += cb ["total_cost_USD" ]
134
130
135
-
136
-
137
- current_node_connections = self .edges [current_node ]
138
- if current_node .node_type == 'conditional_node' :
139
- # Assert that there are exactly two out edges from the conditional node
140
- if len (current_node_connections ) != 2 :
141
- raise ValueError (f"Conditional node should have exactly two out connections { current_node_connections .node_name } " )
142
- if result ["next_node" ] == 0 :
143
- queue .append (current_node_connections [0 ])
144
- else :
145
- queue .append (current_node_connections [1 ])
146
- # remove the conditional node result
147
- del result ["next_node" ]
148
- else :
149
- queue .extend (node for node in current_node_connections )
150
-
151
-
152
- exec_info .append ({
153
- "node_name" : "TOTAL RESULT" ,
154
- "total_tokens" : cb_total ["total_tokens" ],
155
- "prompt_tokens" : cb_total ["prompt_tokens" ],
156
- "completion_tokens" : cb_total ["completion_tokens" ],
157
- "successful_requests" : cb_total ["successful_requests" ],
158
- "total_cost_USD" : cb_total ["total_cost_USD" ],
159
- "exec_time" : total_exec_time ,
160
- })
161
-
162
- return state , exec_info
131
+ if current_node .node_type == "conditional_node" :
132
+ current_node_name = result
133
+ elif current_node_name in self .edges :
134
+ current_node_name = self .edges [current_node_name ]
135
+ else :
136
+ current_node_name = None
137
+
138
+ exec_info .append ({
139
+ "node_name" : "TOTAL RESULT" ,
140
+ "total_tokens" : cb_total ["total_tokens" ],
141
+ "prompt_tokens" : cb_total ["prompt_tokens" ],
142
+ "completion_tokens" : cb_total ["completion_tokens" ],
143
+ "successful_requests" : cb_total ["successful_requests" ],
144
+ "total_cost_USD" : cb_total ["total_cost_USD" ],
145
+ "exec_time" : total_exec_time ,
146
+ })
147
+
148
+ return state , exec_info
0 commit comments