Skip to content

Commit 42d2ab1

Browse files
authored
Merge pull request #250 from mayurdb/graphRevamp
fix: Add a new graph traversal that allows more than one edges out of a graph
2 parents 8727d03 + 0b71b9a commit 42d2ab1

File tree

2 files changed

+55
-53
lines changed

2 files changed

+55
-53
lines changed

scrapegraphai/graphs/base_graph.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from langchain_community.callbacks import get_openai_callback
88
from typing import Tuple
9+
from collections import deque
910

1011

1112
class BaseGraph:
@@ -26,6 +27,8 @@ class BaseGraph:
2627
2728
Raises:
2829
Warning: If the entry point node is not the first node in the list.
30+
ValueError: If conditional_node does not have exactly two outgoing edges
31+
2932
3033
Example:
3134
>>> BaseGraph(
@@ -48,7 +51,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):
4851

4952
self.nodes = nodes
5053
self.edges = self._create_edges({e for e in edges})
51-
self.entry_point = entry_point.node_name
54+
self.entry_point = entry_point
5255

5356
if nodes[0].node_name != entry_point.node_name:
5457
# raise a warning if the entry point is not the first node in the list
@@ -68,13 +71,16 @@ def _create_edges(self, edges: list) -> dict:
6871

6972
edge_dict = {}
7073
for from_node, to_node in edges:
71-
edge_dict[from_node.node_name] = to_node.node_name
74+
if from_node in edge_dict:
75+
edge_dict[from_node].append(to_node)
76+
else:
77+
edge_dict[from_node] = [to_node]
7278
return edge_dict
7379

7480
def execute(self, initial_state: dict) -> Tuple[dict, list]:
7581
"""
76-
Executes the graph by traversing nodes starting from the entry point. The execution
77-
follows the edges based on the result of each node's execution and continues until
82+
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
83+
The execution follows the edges based on the result of each node's execution and continues until
7884
it reaches a node with no outgoing edges.
7985
8086
Args:
@@ -84,7 +90,6 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
8490
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
8591
"""
8692

87-
current_node_name = self.nodes[0]
8893
state = initial_state
8994

9095
# variables for tracking execution info
@@ -98,23 +103,22 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
98103
"total_cost_USD": 0.0,
99104
}
100105

101-
for index in self.nodes:
102-
106+
queue = deque([self.entry_point])
107+
while queue:
108+
current_node = queue.popleft()
103109
curr_time = time.time()
104-
current_node = index
105-
106-
with get_openai_callback() as cb:
110+
with get_openai_callback() as callback:
107111
result = current_node.execute(state)
108112
node_exec_time = time.time() - curr_time
109113
total_exec_time += node_exec_time
110114

111115
cb = {
112-
"node_name": index.node_name,
113-
"total_tokens": cb.total_tokens,
114-
"prompt_tokens": cb.prompt_tokens,
115-
"completion_tokens": cb.completion_tokens,
116-
"successful_requests": cb.successful_requests,
117-
"total_cost_USD": cb.total_cost,
116+
"node_name": current_node.node_name,
117+
"total_tokens": callback.total_tokens,
118+
"prompt_tokens": callback.prompt_tokens,
119+
"completion_tokens": callback.completion_tokens,
120+
"successful_requests": callback.successful_requests,
121+
"total_cost_USD": callback.total_cost,
118122
"exec_time": node_exec_time,
119123
}
120124

@@ -128,21 +132,30 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
128132
cb_total["successful_requests"] += cb["successful_requests"]
129133
cb_total["total_cost_USD"] += cb["total_cost_USD"]
130134

131-
if current_node.node_type == "conditional_node":
132-
current_node_name = result
133-
elif current_node_name in self.edges:
134-
current_node_name = self.edges[current_node_name]
135-
else:
136-
current_node_name = None
137-
138-
exec_info.append({
139-
"node_name": "TOTAL RESULT",
140-
"total_tokens": cb_total["total_tokens"],
141-
"prompt_tokens": cb_total["prompt_tokens"],
142-
"completion_tokens": cb_total["completion_tokens"],
143-
"successful_requests": cb_total["successful_requests"],
144-
"total_cost_USD": cb_total["total_cost_USD"],
145-
"exec_time": total_exec_time,
146-
})
135+
if current_node in self.edges:
136+
current_node_connections = self.edges[current_node]
137+
if current_node.node_type == 'conditional_node':
138+
# Assert that there are exactly two out edges from the conditional node
139+
if len(current_node_connections) != 2:
140+
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
141+
if result["next_node"] == 0:
142+
queue.append(current_node_connections[0])
143+
else:
144+
queue.append(current_node_connections[1])
145+
# remove the conditional node result
146+
del result["next_node"]
147+
else:
148+
queue.extend(node for node in current_node_connections)
149+
150+
151+
exec_info.append({
152+
"node_name": "TOTAL RESULT",
153+
"total_tokens": cb_total["total_tokens"],
154+
"prompt_tokens": cb_total["prompt_tokens"],
155+
"completion_tokens": cb_total["completion_tokens"],
156+
"successful_requests": cb_total["successful_requests"],
157+
"total_cost_USD": cb_total["total_cost_USD"],
158+
"exec_time": total_exec_time,
159+
})
147160

148161
return state, exec_info

scrapegraphai/nodes/conditional_node.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,46 +13,33 @@ class ConditionalNode(BaseNode):
1313
This node type is used to implement branching logic within the graph, allowing
1414
for dynamic paths based on the data available in the current state.
1515
16+
It is expected thar exactly two edges are created out of this node.
17+
The first node is chosen for execution if the key exists and has a non-empty value,
18+
and the second node is chosen if the key does not exist or is empty.
19+
1620
Attributes:
1721
key_name (str): The name of the key in the state to check for its presence.
18-
next_nodes (list): A list of two node instances. The first node is chosen
19-
for execution if the key exists and has a non-empty value,
20-
and the second node is chosen if the key does not exist or
21-
is empty.
2222
2323
Args:
2424
key_name (str): The name of the key to check in the graph's state. This is
2525
used to determine the path the graph's execution should take.
26-
next_nodes (list): A list containing exactly two node instances, specifying
27-
the next nodes to execute based on the condition's outcome.
2826
node_name (str, optional): The unique identifier name for the node. Defaults
2927
to "ConditionalNode".
3028
31-
Raises:
32-
ValueError: If next_nodes does not contain exactly two elements, indicating
33-
a misconfiguration in specifying the conditional paths.
3429
"""
3530

36-
def __init__(self, key_name: str, next_nodes: list, node_name="ConditionalNode"):
31+
def __init__(self, key_name: str, node_name="ConditionalNode"):
3732
"""
3833
Initializes the node with the key to check and the next node names based on the condition.
3934
4035
Args:
4136
key_name (str): The name of the key to check in the state.
42-
next_nodes (list): A list containing exactly two names of the next nodes.
43-
The first is used if the key exists, the second if it does not.
44-
45-
Raises:
46-
ValueError: If next_nodes does not contain exactly two elements.
4737
"""
4838

4939
super().__init__(node_name, "conditional_node")
5040
self.key_name = key_name
51-
if len(next_nodes) != 2:
52-
raise ValueError("next_nodes must contain exactly two elements.")
53-
self.next_nodes = next_nodes
5441

55-
def execute(self, state: dict) -> str:
42+
def execute(self, state: dict) -> dict:
5643
"""
5744
Checks if the specified key is present in the state and decides the next node accordingly.
5845
@@ -64,5 +51,7 @@ def execute(self, state: dict) -> str:
6451
"""
6552

6653
if self.key_name in state and len(state[self.key_name]) > 0:
67-
return self.next_nodes[0].node_name
68-
return self.next_nodes[1].node_name
54+
state["next_node"] = 0
55+
else:
56+
state["next_node"] = 1
57+
return state

0 commit comments

Comments
 (0)