Skip to content

Commit f2bb1cc

Browse files
committed
Fixes LC document deserialization
Depends on apache/burr#175.
1 parent 82afa0e commit f2bb1cc

File tree

1 file changed

+57
-25
lines changed

1 file changed

+57
-25
lines changed

scrapegraphai/graphs/smart_scraper_graph_burr.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
SmartScraperGraph Module Burr Version
33
"""
4-
from typing import Tuple
4+
from typing import Tuple, Union
55

66
from burr import tracking
77
from burr.core import Application, ApplicationBuilder, State, default, when
@@ -14,6 +14,7 @@
1414
from langchain_community.document_transformers import Html2TextTransformer, EmbeddingsRedundantFilter
1515
from langchain_community.vectorstores import FAISS
1616
from langchain_core.documents import Document
17+
from langchain_core import load as lc_serde
1718
from langchain_core.output_parsers import JsonOutputParser
1819
from langchain_core.prompts import PromptTemplate
1920
from langchain_core.runnables import RunnableParallel
@@ -67,10 +68,10 @@ def parse_node(state: State, chunk_size: int = 4096) -> tuple[dict, State]:
6768

6869
@action(reads=["user_prompt", "parsed_doc", "doc"],
6970
writes=["relevant_chunks"])
70-
def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[dict, State]:
71-
# bug around input serialization with tracker
72-
llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
73-
embedder_model = OpenAIEmbeddings()
71+
def rag_node(state: State, llm_model: str, embedder_model: object) -> tuple[dict, State]:
72+
# bug around input serialization with tracker -- so instantiate objects here:
73+
llm_model = OpenAI({"model_name": llm_model})
74+
embedder_model = OpenAIEmbeddings() if embedder_model == "openai" else None
7475
user_prompt = state["user_prompt"]
7576
doc = state["parsed_doc"]
7677

@@ -104,8 +105,10 @@ def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[d
104105

105106
@action(reads=["user_prompt", "relevant_chunks", "parsed_doc", "doc"],
106107
writes=["answer"])
107-
def generate_answer_node(state: State, llm_model: object) -> tuple[dict, State]:
108-
llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
108+
def generate_answer_node(state: State, llm_model: str) -> tuple[dict, State]:
109+
# bug around input serialization with tracker -- so instantiate objects here:
110+
llm_model = OpenAI({"model_name": llm_model})
111+
109112
user_prompt = state["user_prompt"]
110113
doc = state.get("relevant_chunks",
111114
state.get("parsed_doc",
@@ -207,21 +210,49 @@ def post_run_step(
207210
):
208211
print(f"Finishing action: {action.name}")
209212

213+
import json
214+
215+
def _deserialize_document(x: Union[str, dict]) -> Document:
    """Rehydrate a persisted LangChain ``Document``.

    Accepts either the dict payload produced by LangChain's serializer or a
    string. A string is first interpreted as serialized JSON; when it is not
    valid JSON it is wrapped verbatim as the page content of a fresh
    ``Document``.

    :param x: serialized document payload (dict) or string form.
    :return: the reconstructed ``Document``.
    :raises ValueError: if ``x`` is neither a dict nor a string.
    """
    # Reject unsupported payload types up front (guard clause).
    if not isinstance(x, (dict, str)):
        raise ValueError("Couldn't deserialize document")
    if isinstance(x, dict):
        # Dict payloads are the canonical LC serialization format.
        return lc_serde.load(x)
    try:
        # Strings are expected to be JSON-serialized LC documents.
        return lc_serde.loads(x)
    except json.JSONDecodeError:
        # Not JSON at all -- treat the raw string as page text.
        return Document(page_content=x)
224+
210225

211226
def run(prompt: str, input_key: str, source: str, config: dict) -> str:
227+
# these configs aren't really used yet.
212228
llm_model = config["llm_model"]
213-
214229
embedder_model = config["embedder_model"]
215-
open_ai_embedder = OpenAIEmbeddings()
230+
# open_ai_embedder = OpenAIEmbeddings()
216231
chunk_size = config["model_token"]
217232

233+
tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
234+
app_instance_id = "testing-12345678919"
218235
initial_state = {
219236
"user_prompt": prompt,
220237
input_key: source,
221238
}
222-
from burr.core import expr
223-
tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
224-
239+
entry_point = "fetch_node"
240+
if app_instance_id:
241+
persisted_state = tracker.load(None, app_id=app_instance_id, sequence_no=None)
242+
if not persisted_state:
243+
print(f"Warning: No persisted state found for app_id {app_instance_id}.")
244+
else:
245+
initial_state = persisted_state["state"]
246+
# for now we need to manually deserialize LangChain messages into LangChain Objects
247+
# i.e. we know which objects need to be LC objects
248+
initial_state = initial_state.update(**{
249+
"doc": _deserialize_document(initial_state["doc"])
250+
})
251+
docs = [_deserialize_document(doc) for doc in initial_state["relevant_chunks"]]
252+
initial_state = initial_state.update(**{
253+
"relevant_chunks": docs
254+
})
255+
entry_point = persisted_state["position"]
225256

226257
app = (
227258
ApplicationBuilder()
@@ -236,16 +267,17 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
236267
("parse_node", "rag_node", default),
237268
("rag_node", "generate_answer_node", default)
238269
)
239-
# .with_entrypoint("fetch_node")
240-
# .with_state(**initial_state)
241-
.initialize_from(
242-
tracker,
243-
resume_at_next_action=True, # always resume from entrypoint in the case of failure
244-
default_state=initial_state,
245-
default_entrypoint="fetch_node",
246-
)
247-
# .with_identifiers(app_id="testing-123456")
248-
.with_tracker(project="smart-scraper-graph")
270+
.with_entrypoint(entry_point)
271+
.with_state(**initial_state)
272+
# this will work once we get serialization plugin for langchain objects done
273+
# .initialize_from(
274+
# tracker,
275+
# resume_at_next_action=True, # always resume from entrypoint in the case of failure
276+
# default_state=initial_state,
277+
# default_entrypoint="fetch_node",
278+
# )
279+
.with_identifiers(app_id=app_instance_id)
280+
.with_tracker(tracker)
249281
.with_hooks(PrintLnHook())
250282
.build()
251283
)
@@ -270,8 +302,8 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
270302
source = "https://en.wikipedia.org/wiki/Paris"
271303
input_key = "url"
272304
config = {
273-
"llm_model": "rag-token",
274-
"embedder_model": "foo",
305+
"llm_model": "gpt-3.5-turbo",
306+
"embedder_model": "openai",
275307
"model_token": "bar",
276308
}
277-
run(prompt, input_key, source, config)
309+
print(run(prompt, input_key, source, config))

0 commit comments

Comments
 (0)