Commit fe3aa28

refactoring of the code
1 parent 7621a7c commit fe3aa28

40 files changed: +84 additions, -193 deletions

scrapegraphai/graphs/abstract_graph.py (11 additions, 8 deletions)

@@ -1,7 +1,6 @@
 """
 AbstractGraph Module
 """
-
 from abc import ABC, abstractmethod
 from typing import Optional
 import uuid
@@ -122,7 +121,7 @@ def _create_llm(self, llm_config: dict) -> object:
         llm_defaults = {"temperature": 0, "streaming": False}
         llm_params = {**llm_defaults, **llm_config}
         rate_limit_params = llm_params.pop("rate_limit", {})
-
+
         if rate_limit_params:
             requests_per_second = rate_limit_params.get("requests_per_second")
             max_retries = rate_limit_params.get("max_retries")
@@ -138,7 +137,7 @@ def _create_llm(self, llm_config: dict) -> object:
                 self.model_token = llm_params["model_tokens"]
             except KeyError as exc:
                 raise KeyError("model_tokens not specified") from exc
-            return llm_params["model_instance"]
+            return llm_params["model_instance"]

         known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
                            "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
@@ -149,16 +148,18 @@ def _create_llm(self, llm_config: dict) -> object:
             llm_params["model"] = split_model_provider[1]

         if llm_params["model_provider"] not in known_providers:
-            raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
+            raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
+                             If possible, try to use a model instance instead.""")

         try:
             self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
         except KeyError:
-            print(f"Model {llm_params['model_provider']}/{llm_params['model']} not found, using default token size (8192)")
+            print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
+                  using default token size (8192)""")
             self.model_token = 8192

         try:
-            if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek", "togetherai"}:
+            if llm_params["model_provider"] not in {"oneapi","nvidia","ernie","deepseek","togetherai"}:
                 if llm_params["model_provider"] == "bedrock":
                     llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
                 with warnings.catch_warnings():
@@ -181,14 +182,16 @@ def _create_llm(self, llm_config: dict) -> object:
             try:
                 from langchain_together import ChatTogether
             except ImportError:
-                raise ImportError("The langchain_together module is not installed. Please install it using `pip install scrapegraphai[other-language-models]`.")
+                raise ImportError("""The langchain_together module is not installed.
+                                  Please install it using `pip install scrapegraphai[other-language-models]`.""")
             return ChatTogether(**llm_params)

         elif model_provider == "nvidia":
             try:
                 from langchain_nvidia_ai_endpoints import ChatNVIDIA
             except ImportError:
-                raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install scrapegraphai[other-language-models]`.")
+                raise ImportError("""The langchain_nvidia_ai_endpoints module is not installed.
+                                  Please install it using `pip install scrapegraphai[other-language-models]`.""")
             return ChatNVIDIA(**llm_params)

         except Exception as e:
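
For orientation, the keys _create_llm pops here ("rate_limit", "model_tokens", "model_instance", and the "model" string that is split on "/" into model_provider and model) come straight from the user-facing graph configuration. A minimal sketch of an "llm" block that exercises the rate_limit branch above; the model name and API key are placeholders, not taken from this commit:

    # Sketch of a graph config as consumed by _create_llm. "model" is split
    # on "/" into model_provider and model; "rate_limit" feeds the branch
    # shown in the hunk above. Model name and api_key are placeholders.
    graph_config = {
        "llm": {
            "model": "openai/gpt-4o-mini",   # "<provider>/<model>"
            "api_key": "sk-...",             # placeholder
            "rate_limit": {
                "requests_per_second": 1,
                "max_retries": 3,
            },
        },
    }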

scrapegraphai/graphs/base_graph.py (2 additions, 15 deletions)

@@ -116,36 +116,28 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             curr_time = time.time()
             current_node = next(node for node in self.nodes if node.node_name == current_node_name)

-            # check if there is a "source" key in the node config
             if current_node.__class__.__name__ == "FetchNode":
-                # get the second key name of the state dictionary
                 source_type = list(state.keys())[1]
                 if state.get("user_prompt", None):
-                    # Set 'prompt' if 'user_prompt' is a string, otherwise None
                     prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None

-                # Convert 'local_dir' source type to 'html_dir'
                 if source_type == "local_dir":
                     source_type = "html_dir"
                 elif source_type == "url":
-                    # If the source is a list, add string URLs to 'source'
                     if isinstance(state[source_type], list):
                         for url in state[source_type]:
                             if isinstance(url, str):
                                 source.append(url)
-                    # If the source is a single string, add it to 'source'
                     elif isinstance(state[source_type], str):
                         source.append(state[source_type])

-            # check if there is an "llm_model" variable in the class
             if hasattr(current_node, "llm_model") and llm_model is None:
                 llm_model = current_node.llm_model
                 if hasattr(llm_model, "model_name"):
                     llm_model = llm_model.model_name
                 elif hasattr(llm_model, "model"):
                     llm_model = llm_model.model

-            # check if there is an "embedder_model" variable in the class
             if hasattr(current_node, "embedder_model") and embedder_model is None:
                 embedder_model = current_node.embedder_model
                 if hasattr(embedder_model, "model_name"):
@@ -157,7 +149,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             if isinstance(current_node.node_config,dict):
                 if current_node.node_config.get("schema", None) and schema is None:
                     if not isinstance(current_node.node_config["schema"], dict):
-                        # convert to dict
                         try:
                             schema = current_node.node_config["schema"].schema()
                         except Exception as e:
@@ -220,7 +211,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
                 "exec_time": total_exec_time,
             })

-            # Log the graph execution telemetry
             graph_execution_time = time.time() - start_time
             response = state.get("answer", None) if source_type == "url" else None
             content = state.get("parsed_doc", None) if response is not None else None
@@ -272,13 +262,10 @@ def append_node(self, node):

         # if node name already exists in the graph, raise an exception
         if node.node_name in {n.node_name for n in self.nodes}:
-            raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
+            raise ValueError(f"""Node with name '{node.node_name}' already exists in the graph.
+                             You can change it by setting the 'node_name' attribute.""")

-        # get the last node in the list
         last_node = self.nodes[-1]
-        # add the edge connecting the last node to the new node
         self.raw_edges.append((last_node, node))
-        # add the node to the list of nodes
         self.nodes.append(node)
-        # update the edges connecting the last node to the new node
         self.edges = self._create_edges({e for e in self.raw_edges})
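
As the last hunk shows, append_node wires the current last node to the new one and rejects duplicate node names. A minimal usage sketch; the node type and its arguments are illustrative (GenerateAnswerNode and its input/output strings appear elsewhere in this commit), while `graph` and `llm_model` are assumed to exist already:

    from scrapegraphai.nodes import GenerateAnswerNode

    def extend_graph(graph, llm_model):
        """Append one more answer node to an existing BaseGraph instance."""
        extra_node = GenerateAnswerNode(
            input="user_prompt & parsed_doc",
            output=["answer"],
            node_config={"llm_model": llm_model},
        )
        graph.append_node(extra_node)  # wires previous last node -> extra_node
        return graph

Appending a second node that reuses an existing node_name raises the reformatted ValueError shown above.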

scrapegraphai/graphs/csv_scraper_graph.py (2 additions, 1 deletion)

@@ -43,7 +43,8 @@ class CSVScraperGraph(AbstractGraph):
     the answer to the prompt as a string.
     run runs the CSVScraperGraph class to extract information from a CSV file based
     on the user's prompt. It requires no additional arguments since all necessary data
-    is stored within the class instance. The method fetches the relevant chunks of text or speech,
+    is stored within the class instance.
+    The method fetches the relevant chunks of text or speech,
     generates an answer based on these chunks, and returns this answer as a string.
     """
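
Reading that docstring alongside the constructors used elsewhere in this commit, a minimal run sketch looks like this; the file name, prompt, and llm settings are placeholders, not taken from the source:

    from scrapegraphai.graphs import CSVScraperGraph

    # Placeholder source path and llm settings.
    csv_graph = CSVScraperGraph(
        prompt="List all products and their prices",
        source="products.csv",
        config={"llm": {"model": "openai/gpt-4o-mini", "api_key": "sk-..."}},
    )
    answer = csv_graph.run()  # returns the answer as a string, per the docstring
    print(answer)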

scrapegraphai/graphs/csv_scraper_multi_graph.py (0 additions, 10 deletions)

@@ -4,8 +4,6 @@

 from typing import List, Optional
 from pydantic import BaseModel
-
-
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
 from .csv_scraper_graph import CSVScraperGraph
@@ -60,20 +58,12 @@ def _create_graph(self) -> BaseGraph:
         BaseGraph: A graph instance representing the web scraping and searching workflow.
         """

-        # ************************************************
-        # Create a CSVScraperGraph instance
-        # ************************************************
-
         smart_scraper_instance = CSVScraperGraph(
             prompt="",
             source="",
             config=self.copy_config,
         )

-        # ************************************************
-        # Define the graph nodes
-        # ************************************************
-
         graph_iterator_node = GraphIteratorNode(
             input="user_prompt & jsons",
             output=["results"],

scrapegraphai/graphs/deep_scraper_graph.py (7 additions, 5 deletions)

@@ -1,7 +1,6 @@
 """
 DeepScraperGraph Module
 """
-
 from typing import Optional
 from pydantic import BaseModel
 from .base_graph import BaseGraph
@@ -54,7 +53,7 @@ class DeepScraperGraph(AbstractGraph):
     """

     def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
-
+
         super().__init__(prompt, config, source, schema)

         self.input_key = "url" if source.startswith("http") else "local_dir"
@@ -79,7 +78,7 @@ def _create_repeated_graph(self) -> BaseGraph:
                 "llm_model": self.llm_model
             }
         )
-
+
         generate_answer_node = GenerateAnswerNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
@@ -89,13 +88,15 @@ def _create_repeated_graph(self) -> BaseGraph:
                 "schema": self.schema
             }
         )
+
         search_node = SearchLinkNode(
             input="user_prompt & relevant_chunks",
             output=["relevant_links"],
             node_config={
                 "llm_model": self.llm_model,
             }
         )
+
         graph_iterator_node = GraphIteratorNode(
             input="user_prompt & relevant_links",
             output=["results"],
@@ -104,6 +105,7 @@ def _create_repeated_graph(self) -> BaseGraph:
                 "batchsize": 1
             }
         )
+
         merge_answers_node = MergeAnswersNode(
             input="user_prompt & results",
             output=["answer"],
@@ -143,8 +145,8 @@ def _create_graph(self) -> BaseGraph:
         """

         base_graph = self._create_repeated_graph()
-        graph_iterator_node = list(filter(lambda x: x.node_name == "GraphIterator", base_graph.nodes))[0]
-        # Graph iterator will repeat the same graph for multiple hyperlinks found within input webpage
+        graph_iterator_node = list(filter(lambda x: x.node_name == "GraphIterator",
+                                          base_graph.nodes))[0]
         graph_iterator_node.node_config["graph_instance"] = self
         return base_graph
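
The reflowed lookup still builds a whole list just to take element zero. base_graph.py in this same commit already uses the lazier next() form for the same job; an equivalent sketch, not part of this commit:

    # Equivalent to list(filter(...))[0], but stops at the first match;
    # raises StopIteration instead of IndexError when no node matches.
    graph_iterator_node = next(node for node in base_graph.nodes
                               if node.node_name == "GraphIterator")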

scrapegraphai/graphs/json_scraper_graph.py (0 additions, 1 deletion)

@@ -1,7 +1,6 @@
 """
 JSONScraperGraph Module
 """
-
 from typing import Optional
 from pydantic import BaseModel
 from .base_graph import BaseGraph

scrapegraphai/graphs/json_scraper_multi_graph.py (2 additions, 10 deletions)

@@ -5,7 +5,6 @@
 from copy import deepcopy
 from typing import List, Optional
 from pydantic import BaseModel
-
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
 from .json_scraper_graph import JSONScraperGraph
@@ -43,7 +42,8 @@ class JSONScraperMultiGraph(AbstractGraph):
     >>> result = search_graph.run()
     """

-    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
+    def __init__(self, prompt: str, source: List[str],
+                 config: dict, schema: Optional[BaseModel] = None):

         self.max_results = config.get("max_results", 3)

@@ -61,21 +61,13 @@ def _create_graph(self) -> BaseGraph:
         BaseGraph: A graph instance representing the web scraping and searching workflow.
         """

-        # ************************************************
-        # Create a JSONScraperGraph instance
-        # ************************************************
-
         smart_scraper_instance = JSONScraperGraph(
             prompt="",
             source="",
             config=self.copy_config,
             schema=self.copy_schema
         )

-        # ************************************************
-        # Define the graph nodes
-        # ************************************************
-
         graph_iterator_node = GraphIteratorNode(
             input="user_prompt & jsons",
             output=["results"],

scrapegraphai/graphs/markdown_scraper_graph.py (5 additions, 1 deletion)

@@ -1,3 +1,6 @@
+"""
+md_scraper module
+"""
 from typing import Optional
 import logging
 from pydantic import BaseModel
@@ -17,7 +20,8 @@ class MDScraperGraph(AbstractGraph):
     config (dict): Configuration parameters for the graph.
     schema (BaseModel): The schema for the graph output.
     llm_model: An instance of a language model client, configured for generating answers.
-    embedder_model: An instance of an embedding model client, configured for generating embeddings.
+    embedder_model: An instance of an embedding model client,
+    configured for generating embeddings.
     verbose (bool): A flag indicating whether to show print statements during execution.
     headless (bool): A flag indicating whether to run the graph in headless mode.

scrapegraphai/graphs/markdown_scraper_multi_graph.py (2 additions, 2 deletions)

@@ -1,7 +1,6 @@
 """
 MDScraperMultiGraph Module
 """
-
 from copy import copy, deepcopy
 from typing import List, Optional
 from pydantic import BaseModel
@@ -42,7 +41,8 @@ class MDScraperMultiGraph(AbstractGraph):
     >>> result = search_graph.run()
     """

-    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
+    def __init__(self, prompt: str, source: List[str],
+                 config: dict, schema: Optional[BaseModel] = None):
         self.copy_config = safe_deepcopy(config)
         self.copy_schema = deepcopy(schema)
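
Note the asymmetry this hunk keeps: the config goes through safe_deepcopy while the schema uses the standard deepcopy. A hypothetical sketch of what such a helper might do; this is NOT the repo's actual implementation, which lives in its utils and may differ:

    from copy import deepcopy

    # Hypothetical fallback-style copy: keep the original reference for
    # members deepcopy cannot handle (e.g. live model clients).
    def safe_deepcopy_sketch(obj):
        try:
            return deepcopy(obj)
        except TypeError:
            return obj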

scrapegraphai/graphs/omni_search_graph.py (0 additions, 10 deletions)

@@ -5,11 +5,9 @@
 from copy import deepcopy
 from typing import Optional
 from pydantic import BaseModel
-
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
 from .omni_scraper_graph import OmniScraperGraph
-
 from ..nodes import (
     SearchInternetNode,
     GraphIteratorNode,
@@ -63,21 +61,13 @@ def _create_graph(self) -> BaseGraph:
         BaseGraph: A graph instance representing the web scraping and searching workflow.
         """

-        # ************************************************
-        # Create a OmniScraperGraph instance
-        # ************************************************
-
         omni_scraper_instance = OmniScraperGraph(
             prompt="",
             source="",
             config=self.copy_config,
             schema=self.copy_schema
         )

-        # ************************************************
-        # Define the graph nodes
-        # ************************************************
-
         search_internet_node = SearchInternetNode(
             input="user_prompt",
             output=["urls"],

scrapegraphai/graphs/pdf_scraper_graph.py (0 additions, 1 deletion)

@@ -2,7 +2,6 @@
 """
 PDFScraperGraph Module
 """
-
 from typing import Optional
 from pydantic import BaseModel
 from .base_graph import BaseGraph

scrapegraphai/graphs/pdf_scraper_multi_graph.py (0 additions, 8 deletions)

@@ -59,21 +59,13 @@ def _create_graph(self) -> BaseGraph:
         BaseGraph: A graph instance representing the web scraping and searching workflow.
         """

-        # ************************************************
-        # Create a PDFScraperGraph instance
-        # ************************************************
-
         pdf_scraper_instance = PDFScraperGraph(
             prompt="",
             source="",
             config=self.copy_config,
             schema=self.copy_schema
         )

-        # ************************************************
-        # Define the graph nodes
-        # ************************************************
-
         graph_iterator_node = GraphIteratorNode(
             input="user_prompt & pdfs",
             output=["results"],

scrapegraphai/graphs/script_creator_graph.py (0 additions, 1 deletion)

@@ -1,7 +1,6 @@
 """
 ScriptCreatorGraph Module
 """
-
 from typing import Optional
 from pydantic import BaseModel
 from .base_graph import BaseGraph
