1
1
"""
2
2
SmartScraperGraph Module Burr Version
3
3
"""
4
- from typing import Tuple
4
+ from typing import Tuple , Union
5
5
6
6
from burr import tracking
7
7
from burr .core import Application , ApplicationBuilder , State , default , when
14
14
from langchain_community .document_transformers import Html2TextTransformer , EmbeddingsRedundantFilter
15
15
from langchain_community .vectorstores import FAISS
16
16
from langchain_core .documents import Document
17
+ from langchain_core import load as lc_serde
17
18
from langchain_core .output_parsers import JsonOutputParser
18
19
from langchain_core .prompts import PromptTemplate
19
20
from langchain_core .runnables import RunnableParallel
@@ -67,10 +68,10 @@ def parse_node(state: State, chunk_size: int = 4096) -> tuple[dict, State]:
67
68
68
69
@action (reads = ["user_prompt" , "parsed_doc" , "doc" ],
69
70
writes = ["relevant_chunks" ])
70
- def rag_node (state : State , llm_model : object , embedder_model : object ) -> tuple [dict , State ]:
71
- # bug around input serialization with tracker
72
- llm_model = OpenAI ({"model_name" : "gpt-3.5-turbo" })
73
- embedder_model = OpenAIEmbeddings ()
71
+ def rag_node (state : State , llm_model : str , embedder_model : object ) -> tuple [dict , State ]:
72
+ # bug around input serialization with tracker -- so instantiate objects here:
73
+ llm_model = OpenAI ({"model_name" : llm_model })
74
+ embedder_model = OpenAIEmbeddings () if embedder_model == "openai" else None
74
75
user_prompt = state ["user_prompt" ]
75
76
doc = state ["parsed_doc" ]
76
77
@@ -104,8 +105,10 @@ def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[d
104
105
105
106
@action (reads = ["user_prompt" , "relevant_chunks" , "parsed_doc" , "doc" ],
106
107
writes = ["answer" ])
107
- def generate_answer_node (state : State , llm_model : object ) -> tuple [dict , State ]:
108
- llm_model = OpenAI ({"model_name" : "gpt-3.5-turbo" })
108
+ def generate_answer_node (state : State , llm_model : str ) -> tuple [dict , State ]:
109
+ # bug around input serialization with tracker -- so instantiate objects here:
110
+ llm_model = OpenAI ({"model_name" : llm_model })
111
+
109
112
user_prompt = state ["user_prompt" ]
110
113
doc = state .get ("relevant_chunks" ,
111
114
state .get ("parsed_doc" ,
@@ -207,21 +210,49 @@ def post_run_step(
207
210
):
208
211
print (f"Finishing action: { action .name } " )
209
212
213
+ import json
214
+
215
+ def _deserialize_document (x : Union [str , dict ]) -> Document :
216
+ if isinstance (x , dict ):
217
+ return lc_serde .load (x )
218
+ elif isinstance (x , str ):
219
+ try :
220
+ return lc_serde .loads (x )
221
+ except json .JSONDecodeError :
222
+ return Document (page_content = x )
223
+ raise ValueError ("Couldn't deserialize document" )
224
+
210
225
211
226
def run (prompt : str , input_key : str , source : str , config : dict ) -> str :
227
+ # these configs aren't really used yet.
212
228
llm_model = config ["llm_model" ]
213
-
214
229
embedder_model = config ["embedder_model" ]
215
- open_ai_embedder = OpenAIEmbeddings ()
230
+ # open_ai_embedder = OpenAIEmbeddings()
216
231
chunk_size = config ["model_token" ]
217
232
233
+ tracker = tracking .LocalTrackingClient (project = "smart-scraper-graph" )
234
+ app_instance_id = "testing-12345678919"
218
235
initial_state = {
219
236
"user_prompt" : prompt ,
220
237
input_key : source ,
221
238
}
222
- from burr .core import expr
223
- tracker = tracking .LocalTrackingClient (project = "smart-scraper-graph" )
224
-
239
+ entry_point = "fetch_node"
240
+ if app_instance_id :
241
+ persisted_state = tracker .load (None , app_id = app_instance_id , sequence_no = None )
242
+ if not persisted_state :
243
+ print (f"Warning: No persisted state found for app_id { app_instance_id } ." )
244
+ else :
245
+ initial_state = persisted_state ["state" ]
246
+ # for now we need to manually deserialize LangChain messages into LangChain Objects
247
+ # i.e. we know which objects need to be LC objects
248
+ initial_state = initial_state .update (** {
249
+ "doc" : _deserialize_document (initial_state ["doc" ])
250
+ })
251
+ docs = [_deserialize_document (doc ) for doc in initial_state ["relevant_chunks" ]]
252
+ initial_state = initial_state .update (** {
253
+ "relevant_chunks" : docs
254
+ })
255
+ entry_point = persisted_state ["position" ]
225
256
226
257
app = (
227
258
ApplicationBuilder ()
@@ -236,16 +267,17 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
236
267
("parse_node" , "rag_node" , default ),
237
268
("rag_node" , "generate_answer_node" , default )
238
269
)
239
- # .with_entrypoint("fetch_node")
240
- # .with_state(**initial_state)
241
- .initialize_from (
242
- tracker ,
243
- resume_at_next_action = True , # always resume from entrypoint in the case of failure
244
- default_state = initial_state ,
245
- default_entrypoint = "fetch_node" ,
246
- )
247
- # .with_identifiers(app_id="testing-123456")
248
- .with_tracker (project = "smart-scraper-graph" )
270
+ .with_entrypoint (entry_point )
271
+ .with_state (** initial_state )
272
+ # this will work once we get serialization plugin for langchain objects done
273
+ # .initialize_from(
274
+ # tracker,
275
+ # resume_at_next_action=True, # always resume from entrypoint in the case of failure
276
+ # default_state=initial_state,
277
+ # default_entrypoint="fetch_node",
278
+ # )
279
+ .with_identifiers (app_id = app_instance_id )
280
+ .with_tracker (tracker )
249
281
.with_hooks (PrintLnHook ())
250
282
.build ()
251
283
)
@@ -270,8 +302,8 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
270
302
source = "https://en.wikipedia.org/wiki/Paris"
271
303
input_key = "url"
272
304
config = {
273
- "llm_model" : "rag-token " ,
274
- "embedder_model" : "foo " ,
305
+ "llm_model" : "gpt-3.5-turbo " ,
306
+ "embedder_model" : "openai " ,
275
307
"model_token" : "bar" ,
276
308
}
277
- run (prompt , input_key , source , config )
309
+ print ( run (prompt , input_key , source , config ) )
0 commit comments