
Commit 589da1d

Merge pull request #351 from VinciGit00/faiss_integration
Faiss integration
2 parents fa951b4 + edddb68 commit 589da1d

File tree

5 files changed: +29, -6 lines


.gitignore

Lines changed: 4 additions & 0 deletions

@@ -23,6 +23,7 @@ docs/source/_static/
 venv/
 .venv/
 .vscode/
+.conda/

 # exclude pdf, mp3
 *.pdf
@@ -38,3 +39,6 @@ lib/
 *.html
 .idea

+# extras
+cache/
+run_smart_scraper.py

docs/source/scrapers/graph_config.rst

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ Some interesting ones are:
 - `loader_kwargs`: A dictionary with additional parameters to be passed to the `Loader` class, such as `proxy`.
 - `burr_kwargs`: A dictionary with additional parameters to enable `Burr` graphical user interface.
 - `max_images`: The maximum number of images to be analyzed. Useful in `OmniScraperGraph` and `OmniSearchGraph`.
+- `cache_path`: The path where the cache files will be saved. If it already exists, the cache will be loaded from this path.

 .. _Burr:
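For context, a minimal usage sketch of the new `cache_path` option. The prompt, source URL, and `llm` settings below are illustrative placeholders, not part of this commit:

# Usage sketch (assumed config shape for this library; values are placeholders).
from scrapegraphai.graphs import SmartScraperGraph

graph_config = {
    "llm": {"api_key": "YOUR_OPENAI_KEY", "model": "gpt-3.5-turbo"},
    "cache_path": "cache",  # new option: FAISS indexes are saved here and reused on later runs
    "verbose": True,
}

scraper = SmartScraperGraph(
    prompt="List the article titles on the page",
    source="https://example.com",
    config=graph_config,
)
print(scraper.run())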

requirements-dev.txt

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 sphinx==7.1.2
 furo==2024.5.6
 pytest==8.0.0
-burr[start]==0.19.1
+burr[start]==0.22.1

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 4 deletions

@@ -78,6 +78,7 @@ def __init__(self, prompt: str, config: dict,
         self.headless = True if config is None else config.get(
             "headless", True)
         self.loader_kwargs = config.get("loader_kwargs", {})
+        self.cache_path = config.get("cache_path", False)

         # Create the graph
         self.graph = self._create_graph()
@@ -93,15 +94,13 @@ def __init__(self, prompt: str, config: dict,
         else:
             set_verbosity_warning()

-        self.headless = True if config is None else config.get("headless", True)
-        self.loader_kwargs = config.get("loader_kwargs", {})
-
         common_params = {
             "headless": self.headless,
             "verbose": self.verbose,
             "loader_kwargs": self.loader_kwargs,
             "llm_model": self.llm_model,
-            "embedder_model": self.embedder_model
+            "embedder_model": self.embedder_model,
+            "cache_path": self.cache_path,
         }

         self.set_common_params(common_params, overwrite=False)
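The deleted `headless` and `loader_kwargs` assignments were duplicates of the ones set earlier in `__init__` (first hunk), and `common_params` now carries `cache_path` down to every node. As a rough illustration of what `set_common_params(..., overwrite=False)` is expected to do given how it is called here (this helper is a sketch, not the repo's actual implementation):

# Illustrative sketch only: shared settings are pushed into each node,
# and with overwrite=False a value the node already defines is kept.
def set_common_params(nodes, common_params: dict, overwrite: bool = False):
    for node in nodes:
        for key, value in common_params.items():
            if overwrite or getattr(node, key, None) is None:
                setattr(node, key, value)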

scrapegraphai/nodes/rag_node.py

Lines changed: 20 additions & 1 deletion

@@ -3,6 +3,7 @@
 """

 from typing import List, Optional
+import os

 from langchain.docstore.document import Document
 from langchain.retrievers import ContextualCompressionRetriever
@@ -50,6 +51,7 @@ def __init__(
         self.verbose = (
             False if node_config is None else node_config.get("verbose", False)
         )
+        self.cache_path = node_config.get("cache_path", False)

     def execute(self, state: dict) -> dict:
         """
@@ -98,7 +100,24 @@ def execute(self, state: dict) -> dict:
         )
         embeddings = self.embedder_model

-        retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
+        folder_name = self.node_config.get("cache_path", "cache")
+
+        if self.node_config.get("cache_path", False) and not os.path.exists(folder_name):
+            index = FAISS.from_documents(chunked_docs, embeddings)
+            os.makedirs(folder_name)
+            index.save_local(folder_name)
+            self.logger.info("--- (indexes saved to cache) ---")
+
+        elif self.node_config.get("cache_path", False) and os.path.exists(folder_name):
+            index = FAISS.load_local(folder_path=folder_name,
+                                     embeddings=embeddings,
+                                     allow_dangerous_deserialization=True)
+            self.logger.info("--- (indexes loaded from cache) ---")
+
+        else:
+            index = FAISS.from_documents(chunked_docs, embeddings)
+
+        retriever = index.as_retriever()

         redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
         # similarity_threshold could be set, now k=20
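Pulled out of the node, the caching logic above reduces to this standalone sketch. The function name and the `exist_ok=True` simplification are mine; the FAISS calls and the assumed `langchain_community` import path mirror the committed code:

# Standalone sketch of the cache pattern; names are illustrative.
import os

from langchain_community.vectorstores import FAISS  # assumed import path

def build_or_load_retriever(chunked_docs, embeddings, cache_path=None):
    """Build a FAISS retriever, persisting and reusing the index when cache_path is set."""
    if cache_path and os.path.exists(cache_path):
        # Reuse the saved index. allow_dangerous_deserialization is required
        # because FAISS stores pickled metadata alongside the index files.
        index = FAISS.load_local(folder_path=cache_path,
                                 embeddings=embeddings,
                                 allow_dangerous_deserialization=True)
    else:
        index = FAISS.from_documents(chunked_docs, embeddings)
        if cache_path:
            os.makedirs(cache_path, exist_ok=True)
            index.save_local(cache_path)
    return index.as_retriever()

One caveat of this design: whenever the cache folder exists, the saved index is loaded as-is and the documents chunked on the current run are ignored, so a stale cache must be cleared manually to re-index new content.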
