Skip to content

Commit e714a59

Browse files
committed
refactoring of engine
1 parent 1b004d8 commit e714a59

File tree

3 files changed

+13
-15
lines changed

3 files changed

+13
-15
lines changed

examples/openai/custom_graph_openai.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,19 @@
6262
# ************************************************
6363

6464
graph = BaseGraph(
65-
nodes={
65+
nodes=[
6666
robot_node,
6767
fetch_node,
6868
parse_node,
6969
rag_node,
7070
generate_answer_node,
71-
},
71+
],
7272
edges={
7373
(robot_node, fetch_node),
7474
(fetch_node, parse_node),
7575
(parse_node, rag_node),
7676
(rag_node, generate_answer_node)
7777
},
78-
entry_point=robot_node
7978
)
8079

8180
# ************************************************

scrapegraphai/graphs/base_graph.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@ class BaseGraph:
2626
entry_point (BaseNode): The node instance that represents the entry point of the graph.
2727
"""
2828

29-
def __init__(self, nodes: dict, edges: dict, entry_point: str):
29+
def __init__(self, nodes: list, edges: list):
3030
"""
3131
Initializes the graph with nodes, edges, and the entry point.
3232
"""
33-
self.nodes = {node.node_name: node for node in nodes}
33+
self.nodes = nodes
3434
self.edges = self._create_edges(edges)
35-
self.entry_point = entry_point.node_name
3635

37-
def _create_edges(self, edges: dict) -> dict:
36+
def _create_edges(self, edges: list) -> dict:
3837
"""
3938
Helper method to create a dictionary of edges from the given iterable of tuples.
4039
@@ -61,7 +60,7 @@ def execute(self, initial_state: dict) -> dict:
6160
Returns:
6261
dict: The state after execution has completed, which may have been altered by the nodes.
6362
"""
64-
current_node_name = self.entry_point
63+
current_node_name = self.nodes[0]
6564
state = initial_state
6665

6766
# variables for tracking execution info
@@ -75,10 +74,10 @@ def execute(self, initial_state: dict) -> dict:
7574
"total_cost_USD": 0.0,
7675
}
7776

78-
while current_node_name is not None:
77+
for index in self.nodes:
7978

8079
curr_time = time.time()
81-
current_node = self.nodes[current_node_name]
80+
current_node = index
8281

8382
with get_openai_callback() as cb:
8483
result = current_node.execute(state)

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from .abstract_graph import AbstractGraph
1212

13+
1314
class SmartScraperGraph(AbstractGraph):
1415
"""
1516
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
@@ -52,25 +53,24 @@ def _create_graph(self):
5253
)
5354

5455
return BaseGraph(
55-
nodes={
56+
nodes=[
5657
fetch_node,
5758
parse_node,
5859
rag_node,
5960
generate_answer_node,
60-
},
61+
],
6162
edges={
6263
(fetch_node, parse_node),
6364
(parse_node, rag_node),
6465
(rag_node, generate_answer_node)
65-
},
66-
entry_point=fetch_node
66+
}
6767
)
6868

6969
def run(self) -> str:
7070
"""
7171
Executes the web scraping process and returns the answer to the prompt.
7272
"""
73-
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
73+
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
7474
self.final_state, self.execution_info = self.graph.execute(inputs)
7575

7676
return self.final_state.get("answer", "No answer found.")

0 commit comments

Comments
 (0)