Skip to content

Commit d7afdb1

Browse files
authored
Merge pull request #670 from LorenzoPaleari/576-exec-info-misses-nested-graphs
Added CustomOpenAiCallbackManager to ensure exclusive access to the OpenAI callback when graphs are nested.
2 parents 063dd1a + e657113 commit d7afdb1

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

scrapegraphai/graphs/base_graph.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import time
55
import warnings
66
from typing import Tuple
7-
from langchain_community.callbacks import get_openai_callback
87
from ..telemetry import log_graph_execution
8+
from ..utils import CustomOpenAiCallbackManager
99

1010
class BaseGraph:
1111
"""
@@ -52,6 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
5252
self.entry_point = entry_point.node_name
5353
self.graph_name = graph_name
5454
self.initial_state = {}
55+
self.callback_manager = CustomOpenAiCallbackManager()
5556

5657
if nodes[0].node_name != entry_point.node_name:
5758
# raise a warning if the entry point is not the first node in the list
@@ -154,7 +155,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
154155
except Exception as e:
155156
schema = None
156157

157-
with get_openai_callback() as cb:
158+
with self.callback_manager.exclusive_get_openai_callback() as cb:
158159
try:
159160
result = current_node.execute(state)
160161
except Exception as e:
@@ -176,23 +177,24 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
176177
node_exec_time = time.time() - curr_time
177178
total_exec_time += node_exec_time
178179

179-
cb_data = {
180-
"node_name": current_node.node_name,
181-
"total_tokens": cb.total_tokens,
182-
"prompt_tokens": cb.prompt_tokens,
183-
"completion_tokens": cb.completion_tokens,
184-
"successful_requests": cb.successful_requests,
185-
"total_cost_USD": cb.total_cost,
186-
"exec_time": node_exec_time,
187-
}
188-
189-
exec_info.append(cb_data)
190-
191-
cb_total["total_tokens"] += cb_data["total_tokens"]
192-
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
193-
cb_total["completion_tokens"] += cb_data["completion_tokens"]
194-
cb_total["successful_requests"] += cb_data["successful_requests"]
195-
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
180+
if cb is not None:
181+
cb_data = {
182+
"node_name": current_node.node_name,
183+
"total_tokens": cb.total_tokens,
184+
"prompt_tokens": cb.prompt_tokens,
185+
"completion_tokens": cb.completion_tokens,
186+
"successful_requests": cb.successful_requests,
187+
"total_cost_USD": cb.total_cost,
188+
"exec_time": node_exec_time,
189+
}
190+
191+
exec_info.append(cb_data)
192+
193+
cb_total["total_tokens"] += cb_data["total_tokens"]
194+
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
195+
cb_total["completion_tokens"] += cb_data["completion_tokens"]
196+
cb_total["successful_requests"] += cb_data["successful_requests"]
197+
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
196198

197199
if current_node.node_type == "conditional_node":
198200
current_node_name = result

scrapegraphai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from .screenshot_scraping.text_detection import detect_text
1818
from .tokenizer import num_tokens_calculus
1919
from .split_text_into_chunks import split_text_into_chunks
20+
from .custom_openai_callback import CustomOpenAiCallbackManager
scrapegraphai/utils/custom_openai_callback.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import threading
2+
from contextlib import contextmanager
3+
from langchain_community.callbacks import get_openai_callback
4+
5+
class CustomOpenAiCallbackManager:
    """Hand out the OpenAI usage callback to at most one holder at a time.

    A single lock is stored on the class, so it is shared by every instance
    of the manager. Whoever acquires it first gets a real callback from
    ``get_openai_callback()``; anyone who asks while it is still held gets
    ``None`` instead of waiting.
    """

    # Class-level on purpose: all managers compete for the same lock.
    # ``threading.Lock`` is not reentrant, so a second entry made while the
    # lock is held (even from the same thread) fails to acquire it.
    _lock = threading.Lock()

    @contextmanager
    def exclusive_get_openai_callback(self):
        """Yield an OpenAI callback if the shared lock is free, else ``None``.

        The lock is requested without blocking. On success the callback
        produced by ``get_openai_callback()`` is yielded and the lock is
        released when the ``with`` block exits, including when the block
        raises. On failure ``None`` is yielded and the lock is left alone,
        so callers must check for ``None`` before reading usage counters.
        """
        acquired = CustomOpenAiCallbackManager._lock.acquire(blocking=False)
        if not acquired:
            # Somebody else already owns the callback; do not wait for it.
            yield None
            return

        try:
            with get_openai_callback() as cb:
                yield cb
        finally:
            # Always hand the lock back, even if the caller's block raised.
            CustomOpenAiCallbackManager._lock.release()

0 commit comments

Comments
 (0)