Skip to content

Commit 1bcc0bf

Browse files
authored
Merge pull request #620 from goasleep/feature/export_search_engine
feat:expose the search engine params to user
2 parents f51b155 + 8422463 commit 1bcc0bf

12 files changed

+297
-48
lines changed

scrapegraphai/graphs/csv_scraper_multi_graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
CSVScraperMultiGraph Module
33
"""
44

5-
from copy import copy, deepcopy
65
from typing import List, Optional
76
from pydantic import BaseModel
7+
8+
89
from .base_graph import BaseGraph
910
from .abstract_graph import AbstractGraph
1011
from .csv_scraper_graph import CSVScraperGraph
1112
from ..nodes import (
1213
GraphIteratorNode,
1314
MergeAnswersNode
1415
)
16+
from ..utils.copy import safe_deepcopy
1517

1618
class CSVScraperMultiGraph(AbstractGraph):
1719
"""
@@ -46,10 +48,7 @@ def __init__(self, prompt: str, source: List[str],
4648

4749
self.max_results = config.get("max_results", 3)
4850

49-
if all(isinstance(value, str) for value in config.values()):
50-
self.copy_config = copy(config)
51-
else:
52-
self.copy_config = deepcopy(config)
51+
self.copy_config = safe_deepcopy(config)
5352

5453
super().__init__(prompt, config, source, schema)
5554

scrapegraphai/graphs/json_scraper_multi_graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
JSONScraperMultiGraph Module
33
"""
44

5-
from copy import copy, deepcopy
5+
from copy import deepcopy
66
from typing import List, Optional
77
from pydantic import BaseModel
8+
89
from .base_graph import BaseGraph
910
from .abstract_graph import AbstractGraph
1011
from .json_scraper_graph import JSONScraperGraph
1112
from ..nodes import (
1213
GraphIteratorNode,
1314
MergeAnswersNode
1415
)
16+
from ..utils.copy import safe_deepcopy
1517

1618
class JSONScraperMultiGraph(AbstractGraph):
1719
"""
@@ -45,10 +47,7 @@ def __init__(self, prompt: str, source: List[str], config: dict, schema: Optiona
4547

4648
self.max_results = config.get("max_results", 3)
4749

48-
if all(isinstance(value, str) for value in config.values()):
49-
self.copy_config = copy(config)
50-
else:
51-
self.copy_config = deepcopy(config)
50+
self.copy_config = safe_deepcopy(config)
5251

5352
self.copy_schema = deepcopy(schema)
5453

scrapegraphai/graphs/markdown_scraper_multi_graph.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
GraphIteratorNode,
1313
MergeAnswersNode
1414
)
15+
from ..utils.copy import safe_deepcopy
1516

1617
class MDScraperMultiGraph(AbstractGraph):
1718
"""
@@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph):
4243
"""
4344

4445
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
45-
if all(isinstance(value, str) for value in config.values()):
46-
self.copy_config = copy(config)
47-
else:
48-
self.copy_config = deepcopy(config)
49-
46+
self.copy_config = safe_deepcopy(config)
5047
self.copy_schema = deepcopy(schema)
5148

5249
super().__init__(prompt, config, source, schema)

scrapegraphai/graphs/omni_search_graph.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
OmniSearchGraph Module
33
"""
44

5-
from copy import copy, deepcopy
5+
from copy import deepcopy
66
from typing import Optional
77
from pydantic import BaseModel
88

@@ -15,6 +15,7 @@
1515
GraphIteratorNode,
1616
MergeAnswersNode
1717
)
18+
from ..utils.copy import safe_deepcopy
1819

1920

2021
class OmniSearchGraph(AbstractGraph):
@@ -48,10 +49,7 @@ def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None
4849

4950
self.max_results = config.get("max_results", 3)
5051

51-
if all(isinstance(value, str) for value in config.values()):
52-
self.copy_config = copy(config)
53-
else:
54-
self.copy_config = deepcopy(config)
52+
self.copy_config = safe_deepcopy(config)
5553

5654
self.copy_schema = deepcopy(schema)
5755

@@ -85,7 +83,8 @@ def _create_graph(self) -> BaseGraph:
8583
output=["urls"],
8684
node_config={
8785
"llm_model": self.llm_model,
88-
"max_results": self.max_results
86+
"max_results": self.max_results,
87+
"search_engine": self.copy_config.get("search_engine")
8988
}
9089
)
9190
graph_iterator_node = GraphIteratorNode(

scrapegraphai/graphs/pdf_scraper_multi_graph.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
PdfScraperMultiGraph Module
33
"""
44

5-
from copy import copy, deepcopy
5+
from copy import deepcopy
66
from typing import List, Optional
77
from pydantic import BaseModel
88
from .base_graph import BaseGraph
@@ -12,6 +12,7 @@
1212
GraphIteratorNode,
1313
MergeAnswersNode
1414
)
15+
from ..utils.copy import safe_deepcopy
1516

1617
class PdfScraperMultiGraph(AbstractGraph):
1718
"""
@@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph):
4445
def __init__(self, prompt: str, source: List[str],
4546
config: dict, schema: Optional[BaseModel] = None):
4647

47-
if all(isinstance(value, str) for value in config.values()):
48-
self.copy_config = copy(config)
49-
else:
50-
self.copy_config = deepcopy(config)
48+
self.copy_config = safe_deepcopy(config)
5149

5250
self.copy_schema = deepcopy(schema)
5351

scrapegraphai/graphs/script_creator_multi_graph.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
ScriptCreatorMultiGraph Module
33
"""
44

5-
from copy import copy, deepcopy
65
from typing import List, Optional
76

87
from pydantic import BaseModel
@@ -15,6 +14,7 @@
1514
GraphIteratorNode,
1615
MergeGeneratedScriptsNode
1716
)
17+
from ..utils.copy import safe_deepcopy
1818

1919
class ScriptCreatorMultiGraph(AbstractGraph):
2020
"""
@@ -47,10 +47,7 @@ def __init__(self, prompt: str, source: List[str], config: dict, schema: Optiona
4747

4848
self.max_results = config.get("max_results", 3)
4949

50-
if all(isinstance(value, str) for value in config.values()):
51-
self.copy_config = copy(config)
52-
else:
53-
self.copy_config = deepcopy(config)
50+
self.copy_config = safe_deepcopy(config)
5451

5552
super().__init__(prompt, config, source, schema)
5653

scrapegraphai/graphs/search_graph.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SearchGraph Module
33
"""
44

5-
from copy import copy, deepcopy
5+
from copy import deepcopy
66
from typing import Optional, List
77
from pydantic import BaseModel
88

@@ -15,6 +15,7 @@
1515
GraphIteratorNode,
1616
MergeAnswersNode
1717
)
18+
from ..utils.copy import safe_deepcopy
1819

1920
class SearchGraph(AbstractGraph):
2021
"""
@@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph):
4748
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
4849
self.max_results = config.get("max_results", 3)
4950

50-
if all(isinstance(value, str) for value in config.values()):
51-
self.copy_config = copy(config)
52-
else:
53-
self.copy_config = deepcopy(config)
51+
self.copy_config = safe_deepcopy(config)
5452
self.copy_schema = deepcopy(schema)
5553
self.considered_urls = [] # New attribute to store URLs
5654

@@ -78,7 +76,8 @@ def _create_graph(self) -> BaseGraph:
7876
output=["urls"],
7977
node_config={
8078
"llm_model": self.llm_model,
81-
"max_results": self.max_results
79+
"max_results": self.max_results,
80+
"search_engine": self.copy_config.get("search_engine")
8281
}
8382
)
8483
graph_iterator_node = GraphIteratorNode(

scrapegraphai/graphs/smart_scraper_multi_graph.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SmartScraperMultiGraph Module
33
"""
44

5-
from copy import copy, deepcopy
5+
from copy import deepcopy
66
from typing import List, Optional
77
from pydantic import BaseModel
88

@@ -14,6 +14,7 @@
1414
GraphIteratorNode,
1515
MergeAnswersNode
1616
)
17+
from ..utils.copy import safe_deepcopy
1718

1819
class SmartScraperMultiGraph(AbstractGraph):
1920
"""
@@ -48,10 +49,7 @@ def __init__(self, prompt: str, source: List[str],
4849

4950
self.max_results = config.get("max_results", 3)
5051

51-
if all(isinstance(value, str) for value in config.values()):
52-
self.copy_config = copy(config)
53-
else:
54-
self.copy_config = deepcopy(config)
52+
self.copy_config = safe_deepcopy(config)
5553

5654
self.copy_schema = deepcopy(schema)
5755

scrapegraphai/graphs/xml_scraper_multi_graph.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
XMLScraperMultiGraph Module
33
"""
44

5-
from copy import copy, deepcopy
5+
from copy import deepcopy
66
from typing import List, Optional
77
from pydantic import BaseModel
88

@@ -14,6 +14,7 @@
1414
GraphIteratorNode,
1515
MergeAnswersNode
1616
)
17+
from ..utils.copy import safe_deepcopy
1718

1819
class XMLScraperMultiGraph(AbstractGraph):
1920
"""
@@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph):
4647
def __init__(self, prompt: str, source: List[str],
4748
config: dict, schema: Optional[BaseModel] = None):
4849

49-
if all(isinstance(value, str) for value in config.values()):
50-
self.copy_config = copy(config)
51-
else:
52-
self.copy_config = deepcopy(config)
50+
self.copy_config = safe_deepcopy(config)
5351

5452
self.copy_schema = deepcopy(schema)
5553

scrapegraphai/nodes/search_internet_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def __init__(
4141
self.verbose = (
4242
False if node_config is None else node_config.get("verbose", False)
4343
)
44-
self.search_engine = node_config.get("search_engine", "google")
44+
self.search_engine = (
45+
node_config["search_engine"]
46+
if node_config.get("search_engine")
47+
else "google"
48+
)
4549
self.max_results = node_config.get("max_results", 3)
4650

4751
def execute(self, state: dict) -> dict:

scrapegraphai/utils/copy.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import copy
2+
from typing import Any, Dict, Optional
3+
from pydantic.v1 import BaseModel
4+
5+
class DeepCopyError(Exception):
    """Raised when an object can be neither deep-copied nor shallow-copied."""


def safe_deepcopy(obj: Any) -> Any:
    """
    Create a copy of ``obj`` that is as deep as possible.

    ``copy.deepcopy`` is tried first. If it fails with ``TypeError`` or
    ``AttributeError``, dicts, lists, tuples and frozensets are rebuilt
    element by element so that every copyable part is still deep-copied.
    Any other object is shallow-copied with ``copy.copy``; its attributes
    are not inspected further.

    Args:
        obj (Any): The object to be copied, which can be of any type.

    Returns:
        Any: A deep copy of the object if possible; otherwise a rebuilt
        container or a shallow copy.

    Raises:
        DeepCopyError: If the object, or an element nested inside a
        container, can be neither deep-copied nor shallow-copied.
    """
    return _safe_deepcopy(obj, {})


def _safe_deepcopy(obj: Any, memo: dict) -> Any:
    """
    Recursive worker behind ``safe_deepcopy``.

    ``memo`` maps ``id(original)`` to the copy of every dict or list that
    is being rebuilt by hand, so self-referential structures terminate
    instead of recursing forever.
    """
    obj_id = id(obj)
    if obj_id in memo:
        return memo[obj_id]

    try:
        # Try to use copy.deepcopy first.
        return copy.deepcopy(obj)
    except (TypeError, AttributeError) as exc:
        # Keep a reference: the ``as`` name is unbound when the handler exits.
        cause = exc

    # Dict keys are reused as-is; only the values are copied.
    if isinstance(obj, dict):
        new_dict = {}
        # Register the copy before filling it so cycles resolve to it.
        memo[obj_id] = new_dict
        for key, value in obj.items():
            new_dict[key] = _safe_deepcopy(value, memo)
        return new_dict

    if isinstance(obj, list):
        new_list = []
        memo[obj_id] = new_list
        for value in obj:
            new_list.append(_safe_deepcopy(value, memo))
        return new_list

    # Tuples are immutable, but might contain mutable objects.
    if isinstance(obj, tuple):
        items = [_safe_deepcopy(value, memo) for value in obj]
        # Namedtuples expose ``_make``; using it keeps the type and field
        # names instead of degrading the copy to a plain tuple.
        make = getattr(type(obj), "_make", None)
        return make(items) if callable(make) else tuple(items)

    # Frozensets are immutable, but might contain mutable objects.
    if isinstance(obj, frozenset):
        return frozenset(_safe_deepcopy(value, memo) for value in obj)

    # Anything else (including objects with a ``__dict__``) cannot be
    # deep-copied, so fall back to a shallow copy.
    try:
        return copy.copy(obj)
    except (TypeError, AttributeError):
        raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from cause
75+

0 commit comments

Comments
 (0)