
Commit f7ba1f3 (1 parent: 437e48f)

    refactoring of the code

20 files changed: +44 −108 lines

scrapegraphai/graphs/abstract_graph.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -7,8 +7,6 @@
 import uuid
 import warnings
 from pydantic import BaseModel
-from langchain_community.chat_models import ErnieBotChat
-from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from langchain.chat_models import init_chat_model
 from ..helpers import models_tokens
 from ..models import (
@@ -147,8 +145,7 @@ def handle_model(model_name, provider, token_key, default_token=8192):
                 warnings.simplefilter("ignore")
             return init_chat_model(**llm_params)

-        known_models = ["chatgpt","gpt","openai", "azure_openai", "google_genai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]
-
+        known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"}
         if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
             raise ValueError(f"Model '{llm_params['model']}' is not supported")

@@ -198,6 +195,8 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             return DeepSeek(llm_params)

         elif "ernie" in llm_params["model"]:
+            from langchain_community.chat_models import ErnieBotChat
+
             try:
                 self.model_token = models_tokens["ernie"][llm_params["model"]]
             except KeyError:
@@ -215,6 +214,8 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             return OneApi(llm_params)

         elif "nvidia" in llm_params["model"]:
+            from langchain_nvidia_ai_endpoints import ChatNVIDIA
+
             try:
                 self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
                 llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
```

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 12 additions & 11 deletions
```diff
@@ -9,7 +9,8 @@
 from tqdm import tqdm
 from ..utils.logging import get_logger
 from .base_node import BaseNode
-from ..prompts.generate_answer_node_csv_prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
+from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
+                                                        TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)

 class GenerateAnswerCSVNode(BaseNode):
     """
@@ -95,22 +96,22 @@ def execute(self, state):
         else:
             output_parser = JsonOutputParser()

-        TEMPLATE_NO_CHUKS_CSV_prompt = TEMPLATE_NO_CHUKS_CSV
-        TEMPLATE_CHUKS_CSV_prompt = TEMPLATE_CHUKS_CSV
-        TEMPLATE_MERGE_CSV_prompt = TEMPLATE_MERGE_CSV
+        TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
+        TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
+        TEMPLATE_MERGE_CSV_PROMPT = TEMPLATE_MERGE_CSV

         if self.additional_info is not None:
-            TEMPLATE_NO_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_NO_CHUKS_CSV
-            TEMPLATE_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_CHUKS_CSV
-            TEMPLATE_MERGE_CSV_prompt = self.additional_info + TEMPLATE_MERGE_CSV
+            TEMPLATE_NO_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_NO_CHUKS_CSV
+            TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
+            TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV

         format_instructions = output_parser.get_format_instructions()

         chains_dict = {}

         if len(doc) == 1:
             prompt = PromptTemplate(
-                template=TEMPLATE_NO_CHUKS_CSV_prompt,
+                template=TEMPLATE_NO_CHUKS_CSV_PROMPT,
                 input_variables=["question"],
                 partial_variables={
                     "context": doc,
@@ -127,7 +128,7 @@ def execute(self, state):
             tqdm(doc, desc="Processing chunks", disable=not self.verbose)
         ):
             prompt = PromptTemplate(
-                template=TEMPLATE_CHUKS_CSV_prompt,
+                template=TEMPLATE_CHUKS_CSV_PROMPT,
                 input_variables=["question"],
                 partial_variables={
                     "context": chunk,
@@ -144,7 +145,7 @@ def execute(self, state):
         batch_results = async_runner.invoke({"question": user_prompt})

         merge_prompt = PromptTemplate(
-            template = TEMPLATE_MERGE_CSV_prompt,
+            template = TEMPLATE_MERGE_CSV_PROMPT,
             input_variables=["context", "question"],
             partial_variables={"format_instructions": format_instructions},
         )
@@ -153,4 +154,4 @@ def execute(self, state):
         answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})

         state.update({self.output[0]: answer})
-        return state
+        return state
```
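This file mostly renames the local prompt variables to the uppercase `*_PROMPT` convention; the behavior of prefixing `additional_info` onto each template is unchanged. A short runnable sketch of that prompt wiring, with a made-up template standing in for `TEMPLATE_NO_CHUKS_CSV`:

```python
from langchain_core.prompts import PromptTemplate

# Hypothetical template; the real one lives in
# scrapegraphai.prompts.generate_answer_node_csv_prompts.
TEMPLATE_NO_CHUNKS = (
    "Answer the question using the CSV context.\n"
    "{format_instructions}\nContext: {context}\nQuestion: {question}\n"
)
additional_info = "The CSV uses ';' as separator.\n"  # may also be None

# Same shape as the node: optional prefix, then partial binding.
template = (additional_info or "") + TEMPLATE_NO_CHUNKS
prompt = PromptTemplate(
    template=template,
    input_variables=["question"],
    partial_variables={"context": "a;b\n1;2",
                       "format_instructions": "Reply in JSON."},
)
print(prompt.format(question="What is b in the first row?"))
```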

scrapegraphai/nodes/generate_scraper_node.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -67,10 +67,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
```
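The deleted comments narrated a two-line idiom that recurs in most nodes touched by this commit. For reference, the idiom in isolation (the state contents here are illustrative):

```python
# The node's input expression selects state keys; values are read in order.
state = {"user_prompt": "List article titles", "doc": ["<html>...</html>"]}
input_keys = ["user_prompt", "doc"]   # what self.get_input_keys(state) yields
input_data = [state[key] for key in input_keys]
user_prompt, doc = input_data
```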

scrapegraphai/nodes/get_probable_tags_node.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -58,10 +58,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
@@ -88,10 +86,8 @@ def execute(self, state: dict) -> dict:
             },
         )

-        # Execute the chain to get probable tags
         tag_answer = tag_prompt | self.llm_model | output_parser
         probable_tags = tag_answer.invoke({"question": user_prompt})

-        # Update the dictionary with probable tags
         state.update({self.output[0]: probable_tags})
         return state
```
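The surviving line `tag_prompt | self.llm_model | output_parser` is a LangChain LCEL pipeline. A self-contained sketch of the same chain shape that runs offline by swapping in a fake LLM (the template and fake response are invented for illustration):

```python
from langchain_community.llms.fake import FakeListLLM
from langchain_core.output_parsers import CommaSeparatedListOutputParser
from langchain_core.prompts import PromptTemplate

parser = CommaSeparatedListOutputParser()
prompt = PromptTemplate(
    template="List probable HTML tags for: {question}\n{format_instructions}",
    input_variables=["question"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
llm = FakeListLLM(responses=["h1, p, a"])  # stands in for self.llm_model

chain = prompt | llm | parser              # same prompt | model | parser shape
print(chain.invoke({"question": "article titles"}))  # ['h1', 'p', 'a']
```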

scrapegraphai/nodes/graph_iterator_node.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -103,7 +103,6 @@ async def _async_execute(self, state: dict, batchsize: int) -> dict:
         if graph_instance is None:
             raise ValueError("graph instance is required for concurrent execution")

-        # Assign depth level to the graph
         if "graph_depth" in graph_instance.config:
             graph_instance.config["graph_depth"] += 1
         else:
@@ -113,14 +112,12 @@ async def _async_execute(self, state: dict, batchsize: int) -> dict:

         participants = []

-        # semaphore to limit the number of concurrent tasks
         semaphore = asyncio.Semaphore(batchsize)

         async def _async_run(graph):
             async with semaphore:
                 return await asyncio.to_thread(graph.run)

-        # creates a deepcopy of the graph instance for each endpoint
         for url in urls:
             instance = copy.copy(graph_instance)
             instance.source = url
```
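The kept code bounds concurrency with a semaphore while pushing each blocking `graph.run()` onto a worker thread. A runnable reduction of that pattern, where the `run` stub stands in for a graph instance:

```python
import asyncio
import time

def run(url: str) -> str:
    """Stand-in for graph.run(): a blocking call executed off the event loop."""
    time.sleep(0.2)
    return f"scraped {url}"

async def main(urls, batchsize=2):
    semaphore = asyncio.Semaphore(batchsize)

    async def _async_run(url):
        async with semaphore:  # at most `batchsize` runs in flight at once
            return await asyncio.to_thread(run, url)

    return await asyncio.gather(*(_async_run(u) for u in urls))

print(asyncio.run(main([f"https://example.com/{i}" for i in range(5)])))
```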

scrapegraphai/nodes/merge_answers_node.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -56,21 +56,17 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
         answers = input_data[1]

-        # merge the answers in one string
         answers_str = ""
         for i, answer in enumerate(answers):
             answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"

-        # Initialize the output parser
         if self.node_config.get("schema", None) is not None:
             output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
         else:
@@ -90,6 +86,5 @@ def execute(self, state: dict) -> dict:
         merge_chain = prompt_template | self.llm_model | output_parser
         answer = merge_chain.invoke({"user_prompt": user_prompt})

-        # Update the state with the generated answer
         state.update({self.output[0]: answer})
         return state
```
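The schema branch kept here binds a Pydantic model to `JsonOutputParser` so the merge chain can advertise the expected JSON shape in its format instructions. A small sketch with an invented `MergedAnswer` schema:

```python
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel

class MergedAnswer(BaseModel):  # invented schema for illustration
    title: str
    sources: list[str]

parser = JsonOutputParser(pydantic_object=MergedAnswer)
print(parser.get_format_instructions())  # schema hint injected into the prompt
print(parser.parse('{"title": "ACME", "sources": ["a.com", "b.com"]}'))
```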

scrapegraphai/nodes/parse_node.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -59,13 +59,11 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]
-        # Parse the document
         docs_transformed = input_data[0]
+
         if self.parse_html:
             docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
             docs_transformed = docs_transformed[0]
@@ -77,7 +75,6 @@ def execute(self, state: dict) -> dict:
         else:
             docs_transformed = docs_transformed[0]

-        # Adapt the chunk size, leaving room for the reply, the prompt and the schema
         chunk_size = self.node_config.get("chunk_size", 4096)
         chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
```
scrapegraphai/nodes/rag_node.py

Lines changed: 1 addition & 11 deletions
```diff
@@ -80,10 +80,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
@@ -102,7 +100,6 @@ def execute(self, state: dict) -> dict:

         self.logger.info("--- (updated chunks metadata) ---")

-        # check if embedder_model is provided, if not use llm_model
         if self.embedder_model is not None:
             embeddings = self.embedder_model
         elif 'embeddings' in self.node_config:
@@ -144,23 +141,17 @@ def execute(self, state: dict) -> dict:
         pipeline_compressor = DocumentCompressorPipeline(
             transformers=[redundant_filter, relevant_filter]
         )
-        # redundant + relevant filter compressor
         compression_retriever = ContextualCompressionRetriever(
             base_compressor=pipeline_compressor, base_retriever=retriever
         )

-        # relevant filter compressor only
-        # compression_retriever = ContextualCompressionRetriever(
-        #     base_compressor=relevant_filter, base_retriever=retriever
-        # )
-
         compressed_docs = compression_retriever.invoke(user_prompt)

         self.logger.info("--- (tokens compressed and vector stored) ---")

         state.update({self.output[0]: compressed_docs})
         return state
-
+

     def _create_default_embedder(self, llm_config=None) -> object:
         """
@@ -223,7 +214,6 @@ def _create_embedder(self, embedder_config: dict) -> object:
         embedder_params = {**embedder_config}
         if "model_instance" in embedder_config:
             return embedder_params["model_instance"]
-        # Instantiate the embedding model based on the model name
         if "openai" in embedder_params["model"]:
             return OpenAIEmbeddings(api_key=embedder_params["api_key"])
         if "azure" in embedder_params["model"]:
```

scrapegraphai/nodes/robots_node.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -75,10 +75,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         source = input_data[0]
```

scrapegraphai/nodes/search_internet_node.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -67,7 +67,6 @@ def execute(self, state: dict) -> dict:

         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
@@ -79,10 +78,8 @@ def execute(self, state: dict) -> dict:
             input_variables=["user_prompt"],
         )

-        # Execute the chain to get the search query
         search_answer = search_prompt | self.llm_model | output_parser
-
-        # Ollama: Use no json format when creating the search query
+
         if isinstance(self.llm_model, ChatOllama) and self.llm_model.format == 'json':
             self.llm_model.format = None
         search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
@@ -96,9 +93,7 @@ def execute(self, state: dict) -> dict:
                              search_engine=self.search_engine)

         if len(answer) == 0:
-            # raise an exception if no answer is found
             raise ValueError("Zero results found for the search query.")

-        # Update the state with the generated answer
         state.update({self.output[0]: answer})
         return state
```
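The deleted comment documented the Ollama special case that survives in the code: JSON mode is useful for structured extraction elsewhere, but a search query should be free text, so the flag is cleared before invoking the chain. In isolation (the model name is illustrative; no server call happens at construction):

```python
from langchain_community.chat_models import ChatOllama

llm_model = ChatOllama(model="llama3", format="json")
if isinstance(llm_model, ChatOllama) and llm_model.format == "json":
    llm_model.format = None  # emit a plain-text search query, not JSON
```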

scrapegraphai/nodes/search_link_node.py

Lines changed: 7 additions & 13 deletions
```diff
@@ -49,7 +49,6 @@ def __init__(
             self.filter_config = {**default_filters.filter_dict, **provided_filter_config}
             self.filter_links = True
         else:
-            # Skip filtering if not enabled
             self.filter_config = None
             self.filter_links = False

@@ -58,42 +57,38 @@ def __init__(

     def _is_same_domain(self, url, domain):
         if not self.filter_links or not self.filter_config.get("diff_domain_filter", True):
-            return True # Skip the domain filter if not enabled
+            return True
         parsed_url = urlparse(url)
         parsed_domain = urlparse(domain)
         return parsed_url.netloc == parsed_domain.netloc

     def _is_image_url(self, url):
         if not self.filter_links:
-            return False # Skip image filtering if filtering is not enabled
-
+            return False
         image_extensions = self.filter_config.get("img_exts", [])
         return any(url.lower().endswith(ext) for ext in image_extensions)

     def _is_language_url(self, url):
         if not self.filter_links:
-            return False # Skip language filtering if filtering is not enabled
+            return False

         lang_indicators = self.filter_config.get("lang_indicators", [])
         parsed_url = urlparse(url)
         query_params = parse_qs(parsed_url.query)

-        # Check if the URL path or query string indicates a language-specific version
         return any(indicator in parsed_url.path.lower() or indicator in query_params for indicator in lang_indicators)
-
     def _is_potentially_irrelevant(self, url):
         if not self.filter_links:
             return False  # Skip irrelevant URL filtering if filtering is not enabled

         irrelevant_keywords = self.filter_config.get("irrelevant_keywords", [])
         return any(keyword in url.lower() for keyword in irrelevant_keywords)

-
+
     def execute(self, state: dict) -> dict:
         """
-        Filter out relevant links from the webpage that are relavant to prompt. Out of the filtered links, also
-        ensure that all links are navigable.
-
+        Filter out relevant links from the webpage that are relavant to prompt.
+        Out of the filtered links, also ensure that all links are navigable.
         Args:
             state (dict): The current state of the graph. The input keys will be used to fetch the
                 correct data types from the state.
@@ -108,7 +103,6 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-
         parsed_content_chunks = state.get("doc")
         source_url = state.get("url") or state.get("local_dir")
         output_parser = JsonOutputParser()
@@ -148,7 +142,7 @@ def execute(self, state: dict) -> dict:
         except Exception as e:
             # Fallback approach: Using the LLM to extract links
             self.logger.error(f"Error extracting links: {e}. Falling back to LLM.")
-
+
             merge_prompt = PromptTemplate(
                 template=TEMPLATE_RELEVANT_LINKS,
                 input_variables=["content", "user_prompt"],
```

scrapegraphai/nodes/search_node_with_context.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -58,10 +58,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         doc = input_data[1]
@@ -71,7 +69,6 @@ def execute(self, state: dict) -> dict:

         result = []

-        # Use tqdm to add progress bar
         for i, chunk in enumerate(
             tqdm(doc, desc="Processing chunks", disable=not self.verbose)
         ):
```
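The removed comment described the tqdm loop that remains: the progress bar is suppressed entirely unless the node is verbose. In isolation (chunk values are illustrative):

```python
from tqdm import tqdm

doc, verbose = ["chunk-1", "chunk-2", "chunk-3"], False
for i, chunk in enumerate(
    tqdm(doc, desc="Processing chunks", disable=not verbose)
):
    pass  # per-chunk work happens here
```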
