4
4
"""
5
5
6
6
import re
7
+ import uuid
8
+ from hashlib import md5
7
9
from typing import Any , Dict , List , Tuple
8
10
import inspect
9
11
13
15
raise ImportError ("burr package is not installed. Please install it with 'pip install scrapegraphai[burr]'" )
14
16
15
17
from burr import tracking
16
- from burr .core import Application , ApplicationBuilder , State , Action , default
18
+ from burr .core import Application , ApplicationBuilder , State , Action , default , ApplicationContext
17
19
from burr .lifecycle import PostRunStepHook , PreRunStepHook
18
20
19
21
@@ -55,7 +57,7 @@ def writes(self) -> list[str]:
55
57
56
58
def update (self , result : dict , state : State ) -> State :
57
59
return state .update (** result )
58
-
60
+
59
61
def get_source (self ) -> str :
60
62
return inspect .getsource (self .node .__class__ )
61
63
@@ -100,13 +102,12 @@ class BurrBridge:
100
102
def __init__ (self , base_graph , burr_config ):
101
103
self .base_graph = base_graph
102
104
self .burr_config = burr_config
103
- self .project_name = burr_config .get ("project_name" , "default-project" )
104
- self .tracker = tracking .LocalTrackingClient (project = self .project_name )
105
+ self .project_name = burr_config .get ("project_name" , "scrapegraph: {}" )
105
106
self .app_instance_id = burr_config .get ("app_instance_id" , "default-instance" )
106
107
self .burr_inputs = burr_config .get ("inputs" , {})
107
108
self .burr_app = None
108
109
109
- def _initialize_burr_app (self , initial_state : Dict [str , Any ] = {} ) -> Application :
110
+ def _initialize_burr_app (self , initial_state : Dict [str , Any ] = None ) -> Application :
110
111
"""
111
112
Initialize a Burr application from the base graph.
112
113
@@ -116,24 +117,41 @@ def _initialize_burr_app(self, initial_state: Dict[str, Any] = {}) -> Applicatio
116
117
Returns:
117
118
Application: The Burr application instance.
118
119
"""
120
+ if initial_state is None :
121
+ initial_state = {}
119
122
120
123
actions = self ._create_actions ()
121
124
transitions = self ._create_transitions ()
122
125
hooks = [PrintLnHook ()]
123
126
burr_state = State (initial_state )
124
-
125
- app = (
127
+ application_context = ApplicationContext . get ()
128
+ builder = (
126
129
ApplicationBuilder ()
127
130
.with_actions (** actions )
128
131
.with_transitions (* transitions )
129
132
.with_entrypoint (self .base_graph .entry_point )
130
133
.with_state (** burr_state )
131
- .with_identifiers (app_id = self .app_instance_id )
132
- .with_tracker (self .tracker )
134
+ .with_identifiers (app_id = str (uuid .uuid4 ())) # TODO -- grab this from state
133
135
.with_hooks (* hooks )
134
- .build ()
135
136
)
136
- return app
137
+ if application_context is not None :
138
+ builder = (
139
+ builder
140
+ # if we're using a tracker, we want to copy it/pass in
141
+ .with_tracker (
142
+ application_context .tracker .copy () if application_context .tracker is not None else None
143
+ ) # remember to do `copy()` here!
144
+ .with_spawning_parent (
145
+ application_context .app_id ,
146
+ application_context .sequence_id ,
147
+ application_context .partition_key ,
148
+ )
149
+ )
150
+ else :
151
+ # This is the case in which nothing is spawning it
152
+ # in this case, we want to create a new tracker from scratch
153
+ builder = builder .with_tracker (tracking .LocalTrackingClient (project = self .project_name ))
154
+ return builder .build ()
137
155
138
156
def _create_actions (self ) -> Dict [str , Any ]:
139
157
"""
0 commit comments