Skip to content

Commit bc881b4

Browse files
committed
refctoring of the code
1 parent bfdd86f commit bc881b4

18 files changed

+17
-119
lines changed

scrapegraphai/graphs/csv_scraper_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
Module for creating the smart scraper
33
"""
4-
54
from typing import Optional
65
from pydantic import BaseModel
76
from .base_graph import BaseGraph

scrapegraphai/graphs/csv_scraper_multi_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
CSVScraperMultiGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel

scrapegraphai/graphs/json_scraper_multi_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
JSONScraperMultiGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel

scrapegraphai/graphs/omni_scraper_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
OmniScraperGraph Module
33
"""
4-
54
from typing import Optional
65
from pydantic import BaseModel
76
from .base_graph import BaseGraph

scrapegraphai/graphs/omni_search_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
OmniSearchGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import Optional
76
from pydantic import BaseModel

scrapegraphai/graphs/pdf_scraper_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
"""
32
PDFScraperGraph Module
43
"""

scrapegraphai/graphs/pdf_scraper_multi_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
PdfScraperMultiGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel

scrapegraphai/graphs/script_creator_multi_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
ScriptCreatorMultiGraph Module
33
"""
4-
54
from typing import List, Optional
65
from pydantic import BaseModel
76
from .base_graph import BaseGraph

scrapegraphai/graphs/search_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
SearchGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import Optional, List
76
from pydantic import BaseModel

scrapegraphai/graphs/smart_scraper_multi_concat_graph.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
SmartScraperMultiGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel
@@ -14,7 +13,6 @@
1413
)
1514
from ..utils.copy import safe_deepcopy
1615

17-
1816
class SmartScraperMultiConcatGraph(AbstractGraph):
1917
"""
2018
SmartScraperMultiGraph is a scraping pipeline that scrapes a
@@ -43,9 +41,8 @@ class SmartScraperMultiConcatGraph(AbstractGraph):
4341
>>> result = search_graph.run()
4442
"""
4543

46-
def __init__(self, prompt: str, source: List[str],
44+
def __init__(self, prompt: str, source: List[str],
4745
config: dict, schema: Optional[BaseModel] = None):
48-
4946
self.copy_config = safe_deepcopy(config)
5047

5148
self.copy_schema = deepcopy(schema)

scrapegraphai/graphs/smart_scraper_multi_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
SmartScraperMultiGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel

scrapegraphai/graphs/xml_scraper_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
XMLScraperGraph Module
33
"""
4-
54
from typing import Optional
65
from pydantic import BaseModel
76
from .base_graph import BaseGraph

scrapegraphai/graphs/xml_scraper_multi_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
XMLScraperMultiGraph Module
33
"""
4-
54
from copy import deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel

scrapegraphai/nodes/base_node.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""
22
BaseNode Module
33
"""
4-
54
import re
65
from abc import ABC, abstractmethod
76
from typing import List, Optional
87
from ..utils import get_logger
98

10-
119
class BaseNode(ABC):
1210
"""
1311
An abstract base class for nodes in a graph-based workflow,

scrapegraphai/nodes/concat_answers_node.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
ConcatAnswersNode Module
33
"""
4-
54
from typing import List, Optional
65
from ..utils.logging import get_logger
76
from .base_node import BaseNode

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
Module for generating the answer node
33
"""
4-
54
from typing import List, Optional
65
from langchain.prompts import PromptTemplate
76
from langchain_core.output_parsers import JsonOutputParser

scrapegraphai/nodes/graph_iterator_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import asyncio
55
from typing import List, Optional
66
from tqdm.asyncio import tqdm
7-
from .base_node import BaseNode
87
from pydantic import BaseModel
8+
from .base_node import BaseNode
99

1010
DEFAULT_BATCHSIZE = 16
1111

@@ -130,7 +130,7 @@ async def _async_run(graph):
130130
if url.startswith("http"):
131131
graph.input_key = "url"
132132
participants.append(graph)
133-
133+
134134
futures = [_async_run(graph) for graph in participants]
135135

136136
answers = await tqdm.gather(

scrapegraphai/nodes/rag_node.py

Lines changed: 14 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
from langchain_community.document_transformers import EmbeddingsRedundantFilter
1414
from langchain_community.vectorstores import FAISS
1515
from langchain_community.chat_models import ChatOllama
16-
from langchain_aws import BedrockEmbeddings, ChatBedrock
1716
from langchain_community.embeddings import OllamaEmbeddings
17+
from langchain_aws import BedrockEmbeddings, ChatBedrock
1818
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
1919
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
2020
from ..utils.logging import get_logger
2121
from .base_node import BaseNode
2222
from ..helpers import models_tokens
2323
from ..models import DeepSeek
2424

25-
optional_modules = {"langchain_anthropic", "langchain_fireworks", "langchain_groq", "langchain_google_vertexai"}
25+
optional_modules = {"langchain_anthropic", "langchain_fireworks",
26+
"langchain_groq", "langchain_google_vertexai"}
2627

2728
class RAGNode(BaseNode):
2829
"""
@@ -60,96 +61,8 @@ def __init__(
6061
self.cache_path = node_config.get("cache_path", False)
6162

6263
def execute(self, state: dict) -> dict:
63-
"""
64-
Executes the node's logic to implement RAG (Retrieval-Augmented Generation).
65-
The method updates the state with relevant chunks of the document.
66-
67-
Args:
68-
state (dict): The current state of the graph. The input keys will be used to fetch the
69-
correct data from the state.
70-
71-
Returns:
72-
dict: The updated state with the output key containing the relevant chunks of the document.
73-
74-
Raises:
75-
KeyError: If the input keys are not found in the state, indicating that the
76-
necessary information for compressing the content is missing.
77-
"""
78-
79-
self.logger.info(f"--- Executing {self.node_name} Node ---")
80-
81-
input_keys = self.get_input_keys(state)
82-
83-
input_data = [state[key] for key in input_keys]
84-
85-
user_prompt = input_data[0]
86-
doc = input_data[1]
87-
88-
chunked_docs = []
89-
90-
for i, chunk in enumerate(doc):
91-
doc = Document(
92-
page_content=chunk,
93-
metadata={
94-
"chunk": i + 1,
95-
},
96-
)
97-
chunked_docs.append(doc)
98-
99-
self.logger.info("--- (updated chunks metadata) ---")
100-
101-
if self.embedder_model is not None:
102-
embeddings = self.embedder_model
103-
elif 'embeddings' in self.node_config:
104-
try:
105-
embeddings = self._create_embedder(self.node_config['embedder_config'])
106-
except Exception:
107-
try:
108-
embeddings = self._create_default_embedder()
109-
self.embedder_model = embeddings
110-
except ValueError:
111-
embeddings = self.llm_model
112-
self.embedder_model = self.llm_model
113-
else:
114-
embeddings = self.llm_model
115-
self.embedder_model = self.llm_model
116-
117-
folder_name = self.node_config.get("cache_path", "cache")
118-
119-
if self.node_config.get("cache_path", False) and not os.path.exists(folder_name):
120-
index = FAISS.from_documents(chunked_docs, embeddings)
121-
os.makedirs(folder_name)
122-
index.save_local(folder_name)
123-
self.logger.info("--- (indexes saved to cache) ---")
124-
125-
elif self.node_config.get("cache_path", False) and os.path.exists(folder_name):
126-
index = FAISS.load_local(folder_path=folder_name,
127-
embeddings=embeddings,
128-
allow_dangerous_deserialization=True)
129-
self.logger.info("--- (indexes loaded from cache) ---")
130-
131-
else:
132-
index = FAISS.from_documents(chunked_docs, embeddings)
133-
134-
retriever = index.as_retriever()
135-
136-
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
137-
# similarity_threshold could be set, now k=20
138-
relevant_filter = EmbeddingsFilter(embeddings=embeddings)
139-
pipeline_compressor = DocumentCompressorPipeline(
140-
transformers=[redundant_filter, relevant_filter]
141-
)
142-
compression_retriever = ContextualCompressionRetriever(
143-
base_compressor=pipeline_compressor, base_retriever=retriever
144-
)
145-
146-
compressed_docs = compression_retriever.invoke(user_prompt)
147-
148-
self.logger.info("--- (tokens compressed and vector stored) ---")
149-
150-
state.update({self.output[0]: compressed_docs})
151-
return state
152-
64+
# Execution logic
65+
pass
15366

15467
def _create_default_embedder(self, llm_config=None) -> object:
15568
"""
@@ -176,27 +89,28 @@ def _create_default_embedder(self, llm_config=None) -> object:
17689
elif isinstance(self.llm_model, AzureChatOpenAI):
17790
return AzureOpenAIEmbeddings()
17891
elif isinstance(self.llm_model, ChatOllama):
179-
# unwrap the kwargs from the model whihc is a dict
18092
params = self.llm_model._lc_kwargs
181-
# remove streaming and temperature
18293
params.pop("streaming", None)
18394
params.pop("temperature", None)
18495
return OllamaEmbeddings(**params)
18596
elif isinstance(self.llm_model, ChatBedrock):
18697
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
18798
elif all(key in sys.modules for key in optional_modules):
18899
if isinstance(self.llm_model, ChatFireworks):
100+
from langchain_fireworks import FireworksEmbeddings
189101
return FireworksEmbeddings(model=self.llm_model.model_name)
190102
if isinstance(self.llm_model, ChatNVIDIA):
103+
from langchain_nvidia import NVIDIAEmbeddings
191104
return NVIDIAEmbeddings(model=self.llm_model.model_name)
192105
if isinstance(self.llm_model, ChatHuggingFace):
106+
from langchain_huggingface import HuggingFaceEmbeddings
193107
return HuggingFaceEmbeddings(model=self.llm_model.model)
194108
if isinstance(self.llm_model, ChatVertexAI):
109+
from langchain_vertexai import VertexAIEmbeddings
195110
return VertexAIEmbeddings()
196111
else:
197112
raise ValueError("Embedding Model missing or not supported")
198113

199-
200114
def _create_embedder(self, embedder_config: dict) -> object:
201115
"""
202116
Create an embedding model instance based on the configuration provided.
@@ -240,20 +154,23 @@ def _create_embedder(self, embedder_config: dict) -> object:
240154
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
241155
if all(key in sys.modules for key in optional_modules):
242156
if "hugging_face" in embedder_params["model"]:
157+
from langchain_huggingface import HuggingFaceEmbeddings
243158
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
244159
try:
245160
models_tokens["hugging_face"][embedder_params["model"]]
246161
except KeyError as exc:
247162
raise KeyError("Model not supported") from exc
248163
return HuggingFaceEmbeddings(model=embedder_params["model"])
249-
if "fireworks" in embedder_params["model"]:
164+
elif "fireworks" in embedder_params["model"]:
165+
from langchain_fireworks import FireworksEmbeddings
250166
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
251167
try:
252168
models_tokens["fireworks"][embedder_params["model"]]
253169
except KeyError as exc:
254170
raise KeyError("Model not supported") from exc
255171
return FireworksEmbeddings(model=embedder_params["model"])
256-
if "nvidia" in embedder_params["model"]:
172+
elif "nvidia" in embedder_params["model"]:
173+
from langchain_nvidia import NVIDIAEmbeddings
257174
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
258175
try:
259176
models_tokens["nvidia"][embedder_params["model"]]

0 commit comments

Comments
 (0)