Skip to content

Commit c881f64

Browse files
committed
fix(cache): correctly pass the node arguments and logging
1 parent 543b487 commit c881f64

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
sphinx==7.1.2
22
furo==2024.5.6
33
pytest==8.0.0
4-
burr[start]==0.19.1
4+
burr[start]==0.22.1

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, prompt: str, config: dict,
7676
self.headless = True if config is None else config.get(
7777
"headless", True)
7878
self.loader_kwargs = config.get("loader_kwargs", {})
79+
self.cache_path = config.get("cache_path", False)
7980

8081
# Create the graph
8182
self.graph = self._create_graph()
@@ -91,15 +92,13 @@ def __init__(self, prompt: str, config: dict,
9192
else:
9293
set_verbosity_warning()
9394

94-
self.headless = True if config is None else config.get("headless", True)
95-
self.loader_kwargs = config.get("loader_kwargs", {})
96-
9795
common_params = {
9896
"headless": self.headless,
9997
"verbose": self.verbose,
10098
"loader_kwargs": self.loader_kwargs,
10199
"llm_model": self.llm_model,
102-
"embedder_model": self.embedder_model
100+
"embedder_model": self.embedder_model,
101+
"cache_path": self.cache_path,
103102
}
104103

105104
self.set_common_params(common_params, overwrite=False)

scrapegraphai/nodes/rag_node.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
self.verbose = (
5252
False if node_config is None else node_config.get("verbose", False)
5353
)
54+
self.cache_path = node_config.get("cache_path", False)
5455

5556
def execute(self, state: dict) -> dict:
5657
"""
@@ -99,15 +100,20 @@ def execute(self, state: dict) -> dict:
99100
)
100101
embeddings = self.embedder_model
101102

102-
folder_name = self.node_config.get("cache", "cache")
103+
folder_name = self.node_config.get("cache_path", "cache")
103104

104-
if self.node_config.get("cache", False) and not os.path.exists(folder_name):
105+
if self.node_config.get("cache_path", False) and not os.path.exists(folder_name):
105106
index = FAISS.from_documents(chunked_docs, embeddings)
106107
os.makedirs(folder_name)
107-
108108
index.save_local(folder_name)
109-
if self.node_config.get("cache", False) and os.path.exists(folder_name):
110-
index = FAISS.load_local(folder_path=folder_name, embeddings=embeddings)
109+
self.logger.info("--- (indexes saved to cache) ---")
110+
111+
elif self.node_config.get("cache_path", False) and os.path.exists(folder_name):
112+
index = FAISS.load_local(folder_path=folder_name,
113+
embeddings=embeddings,
114+
allow_dangerous_deserialization=True)
115+
self.logger.info("--- (indexes loaded from cache) ---")
116+
111117
else:
112118
index = FAISS.from_documents(chunked_docs, embeddings)
113119

0 commit comments

Comments
 (0)