4
4
import time
5
5
import warnings
6
6
from typing import Tuple
7
- from langchain_community .callbacks import get_openai_callback
8
7
from ..telemetry import log_graph_execution
8
+ from ..utils import CustomOpenAiCallbackManager
9
9
10
10
class BaseGraph :
11
11
"""
@@ -52,6 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
52
52
self .entry_point = entry_point .node_name
53
53
self .graph_name = graph_name
54
54
self .initial_state = {}
55
+ self .callback_manager = CustomOpenAiCallbackManager ()
55
56
56
57
if nodes [0 ].node_name != entry_point .node_name :
57
58
# raise a warning if the entry point is not the first node in the list
@@ -154,7 +155,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
154
155
except Exception as e :
155
156
schema = None
156
157
157
- with get_openai_callback () as cb :
158
+ with self . callback_manager . exclusive_get_openai_callback () as cb :
158
159
try :
159
160
result = current_node .execute (state )
160
161
except Exception as e :
@@ -176,23 +177,24 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
176
177
node_exec_time = time .time () - curr_time
177
178
total_exec_time += node_exec_time
178
179
179
- cb_data = {
180
- "node_name" : current_node .node_name ,
181
- "total_tokens" : cb .total_tokens ,
182
- "prompt_tokens" : cb .prompt_tokens ,
183
- "completion_tokens" : cb .completion_tokens ,
184
- "successful_requests" : cb .successful_requests ,
185
- "total_cost_USD" : cb .total_cost ,
186
- "exec_time" : node_exec_time ,
187
- }
188
-
189
- exec_info .append (cb_data )
190
-
191
- cb_total ["total_tokens" ] += cb_data ["total_tokens" ]
192
- cb_total ["prompt_tokens" ] += cb_data ["prompt_tokens" ]
193
- cb_total ["completion_tokens" ] += cb_data ["completion_tokens" ]
194
- cb_total ["successful_requests" ] += cb_data ["successful_requests" ]
195
- cb_total ["total_cost_USD" ] += cb_data ["total_cost_USD" ]
180
+ if cb is not None :
181
+ cb_data = {
182
+ "node_name" : current_node .node_name ,
183
+ "total_tokens" : cb .total_tokens ,
184
+ "prompt_tokens" : cb .prompt_tokens ,
185
+ "completion_tokens" : cb .completion_tokens ,
186
+ "successful_requests" : cb .successful_requests ,
187
+ "total_cost_USD" : cb .total_cost ,
188
+ "exec_time" : node_exec_time ,
189
+ }
190
+
191
+ exec_info .append (cb_data )
192
+
193
+ cb_total ["total_tokens" ] += cb_data ["total_tokens" ]
194
+ cb_total ["prompt_tokens" ] += cb_data ["prompt_tokens" ]
195
+ cb_total ["completion_tokens" ] += cb_data ["completion_tokens" ]
196
+ cb_total ["successful_requests" ] += cb_data ["successful_requests" ]
197
+ cb_total ["total_cost_USD" ] += cb_data ["total_cost_USD" ]
196
198
197
199
if current_node .node_type == "conditional_node" :
198
200
current_node_name = result
0 commit comments