Skip to content

Commit 73fa31d

Browse files
committed
feat(base_graph): aligned with main
1 parent 02745a4 commit 73fa31d

File tree

1 file changed

+34
-48
lines changed

1 file changed

+34
-48
lines changed

scrapegraphai/graphs/base_graph.py

Lines changed: 34 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import warnings
77
from langchain_community.callbacks import get_openai_callback
88
from typing import Tuple
9-
from collections import deque
109

1110

1211
class BaseGraph:
@@ -27,8 +26,6 @@ class BaseGraph:
2726
2827
Raises:
2928
Warning: If the entry point node is not the first node in the list.
30-
ValueError: If conditional_node does not have exactly two outgoing edges
31-
3229
3330
Example:
3431
>>> BaseGraph(
@@ -51,7 +48,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):
5148

5249
self.nodes = nodes
5350
self.edges = self._create_edges({e for e in edges})
54-
self.entry_point = entry_point
51+
self.entry_point = entry_point.node_name
5552

5653
if nodes[0].node_name != entry_point.node_name:
5754
# raise a warning if the entry point is not the first node in the list
@@ -71,16 +68,13 @@ def _create_edges(self, edges: list) -> dict:
7168

7269
edge_dict = {}
7370
for from_node, to_node in edges:
74-
if from_node in edge_dict:
75-
edge_dict[from_node].append(to_node)
76-
else:
77-
edge_dict[from_node] = [to_node]
71+
edge_dict[from_node.node_name] = to_node.node_name
7872
return edge_dict
7973

8074
def execute(self, initial_state: dict) -> Tuple[dict, list]:
8175
"""
82-
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
83-
The execution follows the edges based on the result of each node's execution and continues until
76+
Executes the graph by traversing nodes starting from the entry point. The execution
77+
follows the edges based on the result of each node's execution and continues until
8478
it reaches a node with no outgoing edges.
8579
8680
Args:
@@ -90,6 +84,7 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
9084
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
9185
"""
9286

87+
current_node_name = self.nodes[0]
9388
state = initial_state
9489

9590
# variables for tracking execution info
@@ -103,22 +98,23 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
10398
"total_cost_USD": 0.0,
10499
}
105100

106-
queue = deque([self.entry_point])
107-
while queue:
108-
current_node = queue.popleft()
101+
for index in self.nodes:
102+
109103
curr_time = time.time()
110-
with get_openai_callback() as callback:
104+
current_node = index
105+
106+
with get_openai_callback() as cb:
111107
result = current_node.execute(state)
112108
node_exec_time = time.time() - curr_time
113109
total_exec_time += node_exec_time
114110

115111
cb = {
116-
"node_name": current_node.node_name,
117-
"total_tokens": callback.total_tokens,
118-
"prompt_tokens": callback.prompt_tokens,
119-
"completion_tokens": callback.completion_tokens,
120-
"successful_requests": callback.successful_requests,
121-
"total_cost_USD": callback.total_cost,
112+
"node_name": index.node_name,
113+
"total_tokens": cb.total_tokens,
114+
"prompt_tokens": cb.prompt_tokens,
115+
"completion_tokens": cb.completion_tokens,
116+
"successful_requests": cb.successful_requests,
117+
"total_cost_USD": cb.total_cost,
122118
"exec_time": node_exec_time,
123119
}
124120

@@ -132,31 +128,21 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
132128
cb_total["successful_requests"] += cb["successful_requests"]
133129
cb_total["total_cost_USD"] += cb["total_cost_USD"]
134130

135-
136-
137-
current_node_connections = self.edges[current_node]
138-
if current_node.node_type == 'conditional_node':
139-
# Assert that there are exactly two out edges from the conditional node
140-
if len(current_node_connections) != 2:
141-
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
142-
if result["next_node"] == 0:
143-
queue.append(current_node_connections[0])
144-
else:
145-
queue.append(current_node_connections[1])
146-
# remove the conditional node result
147-
del result["next_node"]
148-
else:
149-
queue.extend(node for node in current_node_connections)
150-
151-
152-
exec_info.append({
153-
"node_name": "TOTAL RESULT",
154-
"total_tokens": cb_total["total_tokens"],
155-
"prompt_tokens": cb_total["prompt_tokens"],
156-
"completion_tokens": cb_total["completion_tokens"],
157-
"successful_requests": cb_total["successful_requests"],
158-
"total_cost_USD": cb_total["total_cost_USD"],
159-
"exec_time": total_exec_time,
160-
})
161-
162-
return state, exec_info
131+
if current_node.node_type == "conditional_node":
132+
current_node_name = result
133+
elif current_node_name in self.edges:
134+
current_node_name = self.edges[current_node_name]
135+
else:
136+
current_node_name = None
137+
138+
exec_info.append({
139+
"node_name": "TOTAL RESULT",
140+
"total_tokens": cb_total["total_tokens"],
141+
"prompt_tokens": cb_total["prompt_tokens"],
142+
"completion_tokens": cb_total["completion_tokens"],
143+
"successful_requests": cb_total["successful_requests"],
144+
"total_cost_USD": cb_total["total_cost_USD"],
145+
"exec_time": total_exec_time,
146+
})
147+
148+
return state, exec_info

0 commit comments

Comments
 (0)