Skip to content

Commit 0571b6d

Browse files
committed
feat: update base_graph
1 parent 66a29bc commit 0571b6d

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

scrapegraphai/graphs/base_graph.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
"""
2+
base_graph module
3+
"""
14
import time
25
import warnings
3-
from langchain_community.callbacks import get_openai_callback
46
from typing import Tuple
7+
from langchain_community.callbacks import get_openai_callback
8+
from ..integrations import BurrBridge
59

610
# Import telemetry functions
711
from ..telemetry import log_graph_execution, log_event
@@ -56,7 +60,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
5660
# raise a warning if the entry point is not the first node in the list
5761
warnings.warn(
5862
"Careful! The entry point node is different from the first node in the graph.")
59-
63+
6064
# Burr configuration
6165
self.use_burr = use_burr
6266
self.burr_config = burr_config or {}
@@ -79,7 +83,8 @@ def _create_edges(self, edges: list) -> dict:
7983

8084
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
8185
"""
82-
Executes the graph by traversing nodes starting from the entry point using the standard method.
86+
Executes the graph by traversing nodes starting from the
87+
entry point using the standard method.
8388
8489
Args:
8590
initial_state (dict): The initial state to pass to the entry point node.
@@ -114,23 +119,25 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
114119
curr_time = time.time()
115120
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
116121

117-
118122
# check if there is a "source" key in the node config
119123
if current_node.__class__.__name__ == "FetchNode":
120124
# get the second key name of the state dictionary
121125
source_type = list(state.keys())[1]
122126
if state.get("user_prompt", None):
123-
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None
124-
# quick fix for local_dir source type
127+
# Set 'prompt' if 'user_prompt' is a string, otherwise None
128+
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
129+
130+
# Convert 'local_dir' source type to 'html_dir'
125131
if source_type == "local_dir":
126132
source_type = "html_dir"
127133
elif source_type == "url":
128-
if type(state[source_type]) == list:
129-
# iterate through the list of urls and see if they are strings
134+
# If the source is a list, add string URLs to 'source'
135+
if isinstance(state[source_type], list):
130136
for url in state[source_type]:
131-
if type(url) == str:
137+
if isinstance(url, str):
132138
source.append(url)
133-
elif type(state[source_type]) == str:
139+
# If the source is a single string, add it to 'source'
140+
elif isinstance(state[source_type], str):
134141
source.append(state[source_type])
135142

136143
# check if there is an "llm_model" variable in the class
@@ -164,7 +171,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
164171
result = current_node.execute(state)
165172
except Exception as e:
166173
error_node = current_node.node_name
167-
168174
graph_execution_time = time.time() - start_time
169175
log_graph_execution(
170176
graph_name=self.graph_name,
@@ -221,7 +227,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
221227
graph_execution_time = time.time() - start_time
222228
response = state.get("answer", None) if source_type == "url" else None
223229
content = state.get("parsed_doc", None) if response is not None else None
224-
230+
225231
log_graph_execution(
226232
graph_name=self.graph_name,
227233
source=source,
@@ -251,26 +257,25 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
251257

252258
self.initial_state = initial_state
253259
if self.use_burr:
254-
from ..integrations import BurrBridge
255-
260+
256261
bridge = BurrBridge(self, self.burr_config)
257262
result = bridge.execute(initial_state)
258263
return (result["_state"], [])
259264
else:
260265
return self._execute_standard(initial_state)
261-
266+
262267
def append_node(self, node):
263268
"""
264269
Adds a node to the graph.
265270
266271
Args:
267272
node (BaseNode): The node instance to add to the graph.
268273
"""
269-
274+
270275
# if node name already exists in the graph, raise an exception
271276
if node.node_name in {n.node_name for n in self.nodes}:
272277
raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
273-
278+
274279
# get the last node in the list
275280
last_node = self.nodes[-1]
276281
# add the edge connecting the last node to the new node

0 commit comments

Comments
 (0)