1
+ """
2
+ base_graph module
3
+ """
1
4
import time
2
5
import warnings
3
- from langchain_community .callbacks import get_openai_callback
4
6
from typing import Tuple
7
+ from langchain_community .callbacks import get_openai_callback
8
+ from ..integrations import BurrBridge
5
9
6
10
# Import telemetry functions
7
11
from ..telemetry import log_graph_execution , log_event
@@ -56,7 +60,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
56
60
# raise a warning if the entry point is not the first node in the list
57
61
warnings .warn (
58
62
"Careful! The entry point node is different from the first node in the graph." )
59
-
63
+
60
64
# Burr configuration
61
65
self .use_burr = use_burr
62
66
self .burr_config = burr_config or {}
@@ -79,7 +83,8 @@ def _create_edges(self, edges: list) -> dict:
79
83
80
84
def _execute_standard (self , initial_state : dict ) -> Tuple [dict , list ]:
81
85
"""
82
- Executes the graph by traversing nodes starting from the entry point using the standard method.
86
+ Executes the graph by traversing nodes starting from the
87
+ entry point using the standard method.
83
88
84
89
Args:
85
90
initial_state (dict): The initial state to pass to the entry point node.
@@ -114,23 +119,25 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
114
119
curr_time = time .time ()
115
120
current_node = next (node for node in self .nodes if node .node_name == current_node_name )
116
121
117
-
118
122
# check if there is a "source" key in the node config
119
123
if current_node .__class__ .__name__ == "FetchNode" :
120
124
# get the second key name of the state dictionary
121
125
source_type = list (state .keys ())[1 ]
122
126
if state .get ("user_prompt" , None ):
123
- prompt = state ["user_prompt" ] if type (state ["user_prompt" ]) == str else None
124
- # quick fix for local_dir source type
127
+ # Set 'prompt' if 'user_prompt' is a string, otherwise None
128
+ prompt = state ["user_prompt" ] if isinstance (state ["user_prompt" ], str ) else None
129
+
130
+ # Convert 'local_dir' source type to 'html_dir'
125
131
if source_type == "local_dir" :
126
132
source_type = "html_dir"
127
133
elif source_type == "url" :
128
- if type ( state [ source_type ]) == list :
129
- # iterate through the list of urls and see if they are strings
134
+ # If the source is a list, add string URLs to 'source'
135
+ if isinstance ( state [ source_type ], list ):
130
136
for url in state [source_type ]:
131
- if type (url ) == str :
137
+ if isinstance (url , str ) :
132
138
source .append (url )
133
- elif type (state [source_type ]) == str :
139
+ # If the source is a single string, add it to 'source'
140
+ elif isinstance (state [source_type ], str ):
134
141
source .append (state [source_type ])
135
142
136
143
# check if there is an "llm_model" variable in the class
@@ -164,7 +171,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
164
171
result = current_node .execute (state )
165
172
except Exception as e :
166
173
error_node = current_node .node_name
167
-
168
174
graph_execution_time = time .time () - start_time
169
175
log_graph_execution (
170
176
graph_name = self .graph_name ,
@@ -221,7 +227,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
221
227
graph_execution_time = time .time () - start_time
222
228
response = state .get ("answer" , None ) if source_type == "url" else None
223
229
content = state .get ("parsed_doc" , None ) if response is not None else None
224
-
230
+
225
231
log_graph_execution (
226
232
graph_name = self .graph_name ,
227
233
source = source ,
@@ -251,26 +257,25 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
251
257
252
258
self .initial_state = initial_state
253
259
if self .use_burr :
254
- from ..integrations import BurrBridge
255
-
260
+
256
261
bridge = BurrBridge (self , self .burr_config )
257
262
result = bridge .execute (initial_state )
258
263
return (result ["_state" ], [])
259
264
else :
260
265
return self ._execute_standard (initial_state )
261
-
266
+
262
267
def append_node (self , node ):
263
268
"""
264
269
Adds a node to the graph.
265
270
266
271
Args:
267
272
node (BaseNode): The node instance to add to the graph.
268
273
"""
269
-
274
+
270
275
# if node name already exists in the graph, raise an exception
271
276
if node .node_name in {n .node_name for n in self .nodes }:
272
277
raise ValueError (f"Node with name '{ node .node_name } ' already exists in the graph. You can change it by setting the 'node_name' attribute." )
273
-
278
+
274
279
# get the last node in the list
275
280
last_node = self .nodes [- 1 ]
276
281
# add the edge connecting the last node to the new node
0 commit comments