Skip to content

Commit cd8d3e7

Browse files
committed
refactoring of the graphs
1 parent b2ebabd commit cd8d3e7

File tree

7 files changed

+33
-27
lines changed

7 files changed

+33
-27
lines changed

examples/openai/custom_graph_openai.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,13 @@
6969
rag_node,
7070
generate_answer_node,
7171
],
72-
edges={
72+
edges=[
7373
(robot_node, fetch_node),
7474
(fetch_node, parse_node),
7575
(parse_node, rag_node),
7676
(rag_node, generate_answer_node)
77-
},
77+
],
78+
entry_point=robot_node
7879
)
7980

8081
# ************************************************

manual deployment/commit_and_push_with_tests.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ pylint pylint scrapegraphai/**/*.py scrapegraphai/*.py tests/**/*.py
1313

1414
cd tests
1515

16+
poetry install
17+
1618
# Run pytest
1719
if ! pytest; then
1820
echo "Pytest failed. Aborting commit and push."

scrapegraphai/graphs/base_graph.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,29 @@ class BaseGraph:
1111
BaseGraph manages the execution flow of a graph composed of interconnected nodes.
1212
1313
Attributes:
14-
nodes (dict): A dictionary mapping each node's name to its corresponding node instance.
15-
edges (dict): A dictionary representing the directed edges of the graph where each
14+
nodes (list): A dictionary mapping each node's name to its corresponding node instance.
15+
edges (list): A dictionary representing the directed edges of the graph where each
1616
key-value pair corresponds to the from-node and to-node relationship.
1717
entry_point (str): The name of the entry point node from which the graph execution begins.
1818
1919
Methods:
20-
execute(initial_state): Executes the graph's nodes starting from the entry point and
20+
execute(initial_state): Executes the graph's nodes starting from the entry point and
2121
traverses the graph based on the provided initial state.
2222
2323
Args:
2424
nodes (iterable): An iterable of node instances that will be part of the graph.
25-
edges (iterable): An iterable of tuples where each tuple represents a directed edge
25+
edges (iterable): An iterable of tuples where each tuple represents a directed edge
2626
in the graph, defined by a pair of nodes (from_node, to_node).
2727
entry_point (BaseNode): The node instance that represents the entry point of the graph.
2828
"""
2929

30-
def __init__(self, nodes: list, edges: dict, entry_point: str):
30+
def __init__(self, nodes: list, edges: list, entry_point: str):
3131
"""
3232
Initializes the graph with nodes, edges, and the entry point.
3333
"""
3434

35-
self.nodes = {node.node_name: node for node in nodes}
36-
self.edges = self._create_edges(edges)
35+
self.nodes = nodes
36+
self.edges = self._create_edges({e for e in edges})
3737
self.entry_point = entry_point.node_name
3838

3939
if nodes[0].node_name != entry_point.node_name:
@@ -58,8 +58,8 @@ def _create_edges(self, edges: list) -> dict:
5858

5959
def execute(self, initial_state: dict) -> dict:
6060
"""
61-
Executes the graph by traversing nodes starting from the entry point. The execution
62-
follows the edges based on the result of each node's execution and continues until
61+
Executes the graph by traversing nodes starting from the entry point. The execution
62+
follows the edges based on the result of each node's execution and continues until
6363
it reaches a node with no outgoing edges.
6464
6565
Args:
@@ -68,6 +68,7 @@ def execute(self, initial_state: dict) -> dict:
6868
Returns:
6969
dict: The state after execution has completed, which may have been altered by the nodes.
7070
"""
71+
print(self.nodes)
7172
current_node_name = self.nodes[0]
7273
state = initial_state
7374

scrapegraphai/graphs/script_creator_graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Module for creating the smart scraper
33
"""
44
from .base_graph import BaseGraph
@@ -57,17 +57,17 @@ def _create_graph(self):
5757
)
5858

5959
return BaseGraph(
60-
nodes={
60+
nodes=[
6161
fetch_node,
6262
parse_node,
6363
rag_node,
6464
generate_scraper_node,
65-
},
66-
edges={
65+
],
66+
edges=[
6767
(fetch_node, parse_node),
6868
(parse_node, rag_node),
6969
(rag_node, generate_scraper_node)
70-
},
70+
],
7171
entry_point=fetch_node
7272
)
7373

scrapegraphai/graphs/search_graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from .abstract_graph import AbstractGraph
1313

14+
1415
class SearchGraph(AbstractGraph):
1516
"""
1617
Module for searching info on the internet
@@ -49,19 +50,19 @@ def _create_graph(self):
4950
)
5051

5152
return BaseGraph(
52-
nodes={
53+
nodes=[
5354
search_internet_node,
5455
fetch_node,
5556
parse_node,
5657
rag_node,
5758
generate_answer_node,
58-
},
59-
edges={
59+
],
60+
edges=[
6061
(search_internet_node, fetch_node),
6162
(fetch_node, parse_node),
6263
(parse_node, rag_node),
6364
(rag_node, generate_answer_node)
64-
},
65+
],
6566
entry_point=search_internet_node
6667
)
6768

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Module for creating the smart scraper
33
"""
44
from .base_graph import BaseGraph
@@ -59,11 +59,12 @@ def _create_graph(self):
5959
rag_node,
6060
generate_answer_node,
6161
],
62-
edges={
62+
edges=[
6363
(fetch_node, parse_node),
6464
(parse_node, rag_node),
6565
(rag_node, generate_answer_node)
66-
}
66+
],
67+
entry_point=fetch_node
6768
)
6869

6970
def run(self) -> str:

scrapegraphai/graphs/speech_graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@ def _create_graph(self):
6262
)
6363

6464
return BaseGraph(
65-
nodes={
65+
nodes=[
6666
fetch_node,
6767
parse_node,
6868
rag_node,
6969
generate_answer_node,
7070
text_to_speech_node
71-
},
72-
edges={
71+
],
72+
edges=[
7373
(fetch_node, parse_node),
7474
(parse_node, rag_node),
7575
(rag_node, generate_answer_node),
7676
(generate_answer_node, text_to_speech_node)
77-
},
77+
],
7878
entry_point=fetch_node
7979
)
8080

0 commit comments

Comments
 (0)