Skip to content

Commit 4f4d091

Browse files
committed
feat:add deepcopy tool
1 parent ceb522f commit 4f4d091

File tree

2 files changed

+248
-0
lines changed

2 files changed

+248
-0
lines changed

scrapegraphai/utils/copy.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import copy
2+
from typing import Any, Dict, Optional
3+
4+
5+
def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
    """
    Create a deep copy of ``obj``, degrading gracefully when that fails.

    ``copy.deepcopy`` is tried first. If it raises ``TypeError`` or
    ``AttributeError``, dicts, lists, tuples, frozensets and objects that
    expose ``__dict__`` are rebuilt piece by piece through recursive calls
    (tuples and frozensets are rebuilt as plain ``tuple`` / ``frozenset``).
    Anything else is shallow-copied with ``copy.copy``; if that fails too,
    the original object is returned unchanged.

    Args:
        obj (Any): The object to be copied, which can be of any type.
        memo (Optional[Dict[int, Any]]): Maps ``id(original)`` to the copy
            already made for it, so shared and circular references are
            preserved. If None, a new dictionary is created.

    Returns:
        Any: A deep copy of the object if possible; otherwise a copy that is
        as deep as its parts allow, a shallow copy, or, as a last resort,
        the original object.
    """
    if memo is None:
        memo = {}

    obj_id = id(obj)
    if obj_id in memo:
        return memo[obj_id]

    # copy.deepcopy registers containers in ``memo`` *before* filling them.
    # If it fails halfway, those half-built copies would later be returned by
    # the memo lookup above, silently dropping data. Remember the memo's
    # state so a failed attempt can be rolled back.
    snapshot = dict(memo)
    try:
        return copy.deepcopy(obj, memo)
    except (TypeError, AttributeError):
        memo.clear()
        memo.update(snapshot)

    # Mutable containers: register the empty copy first so that circular
    # references met during recursion resolve to it.
    if isinstance(obj, dict):
        new_dict = {}
        memo[obj_id] = new_dict
        for key, value in obj.items():
            new_dict[key] = safe_deepcopy(value, memo)
        return new_dict

    if isinstance(obj, list):
        new_list = []
        memo[obj_id] = new_list
        for item in obj:
            new_list.append(safe_deepcopy(item, memo))
        return new_list

    # Immutable containers can only be registered once built. If a cycle
    # already produced a copy of this very object during the recursion,
    # reuse that one so identity stays consistent.
    if isinstance(obj, tuple):
        new_tuple = tuple(safe_deepcopy(item, memo) for item in obj)
        return memo.setdefault(obj_id, new_tuple)

    if isinstance(obj, frozenset):
        new_frozenset = frozenset(safe_deepcopy(item, memo) for item in obj)
        return memo.setdefault(obj_id, new_frozenset)

    # Plain objects: create an uninitialised instance and copy attributes.
    if hasattr(obj, "__dict__"):
        cls = obj.__class__
        try:
            clone = cls.__new__(cls)
        except TypeError:
            # __new__ needs arguments we cannot supply; use the shallow-copy
            # fallback below instead of raising.
            clone = None
        if clone is not None:
            # Register before recursing so self-references point at the clone.
            memo[obj_id] = clone
            for attr in obj.__dict__:
                setattr(clone, attr, safe_deepcopy(getattr(obj, attr), memo))
            return clone

    # Attempt a shallow copy; if all else fails, return the original object.
    try:
        return copy.copy(obj)
    except (TypeError, AttributeError):
        return obj

tests/utils/copy_utils_test.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import copy
2+
import pytest
3+
4+
# The safe_deepcopy function under test is imported on the next line
5+
from scrapegraphai.utils.copy import safe_deepcopy
6+
7+
8+
class NormalObject:
    """Test helper with a ``__dict__`` whose ``__deepcopy__`` always raises."""

    def __init__(self, value):
        # One caller-supplied attribute plus a mutable list attribute.
        self.value, self.nested = value, [1, 2, 3]

    def __deepcopy__(self, memo):
        # Always refuse, so copy.deepcopy fails with TypeError.
        reason = "Forcing fallback"
        raise TypeError(reason)
15+
16+
17+
class NonDeepcopyable:
    """Test helper holding one value; its ``__deepcopy__`` always raises."""

    def __init__(self, value):
        self.value = value

    def __deepcopy__(self, memo):
        # Always refuse, so copy.deepcopy fails with TypeError.
        reason = "Forcing shallow copy fallback"
        raise TypeError(reason)
23+
24+
25+
class WithoutDict:
    """Slotted test helper (no ``__dict__``) that cannot be deep-copied.

    ``__deepcopy__`` always raises ``TypeError`` and ``__copy__`` hands back
    the instance itself.
    """

    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def __deepcopy__(self, memo):
        reason = "Forcing shallow copy fallback"
        raise TypeError(reason)

    def __copy__(self):
        # A "copy" of this object is the very same instance.
        return self
36+
37+
38+
class NonCopyableObject:
    """Slotted test helper that rejects both deep and shallow copying."""

    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def __deepcopy__(self, memo):
        reason = "fail deep copy "
        raise TypeError(reason)

    def __copy__(self):
        reason = "fail shallow copy"
        raise TypeError(reason)
49+
50+
51+
def test_deepcopy_simple_dict():
    """A dict of plain values is duplicated, including its nested list."""
    source = dict(a=1, b=2, c=[3, 4, 5])
    duplicate = safe_deepcopy(source)
    assert duplicate == source
    assert duplicate is not source
    assert duplicate["c"] is not source["c"]
57+
58+
59+
def test_deepcopy_simple_list():
    """A list of plain values is duplicated, including its nested list."""
    inner = [4, 5]
    source = [1, 2, 3, inner]
    duplicate = safe_deepcopy(source)
    assert duplicate == source
    assert duplicate is not source
    assert duplicate[3] is not source[3]
65+
66+
67+
def test_deepcopy_with_tuple():
    """A tuple holding a mutable list yields a new tuple with a new list."""
    inner = [3, 4]
    source = (1, 2, inner)
    duplicate = safe_deepcopy(source)
    assert duplicate == source
    assert duplicate is not source
    assert duplicate[2] is not source[2]
73+
74+
75+
def test_deepcopy_with_frozenset():
    """A frozenset is copied into an equal but distinct frozenset."""
    members = [1, 2, 3, (4, 5)]
    source = frozenset(members)
    duplicate = safe_deepcopy(source)
    assert duplicate == source
    assert duplicate is not source
80+
81+
82+
def test_deepcopy_with_object():
    """An object whose ``__deepcopy__`` raises is still copied attribute-wise."""
    source = NormalObject(10)
    duplicate = safe_deepcopy(source)
    assert duplicate.value == source.value
    assert duplicate is not source
    # The list attribute must have been copied too, not shared.
    assert duplicate.nested is not source.nested
88+
89+
90+
def test_deepcopy_with_custom_deepcopy_fallback():
    """A dict holding a ``NormalObject`` is copied despite its failing ``__deepcopy__``."""
    source = {"origin": NormalObject(10)}
    duplicate = safe_deepcopy(source)
    assert duplicate is not source
    assert duplicate["origin"].value == source["origin"].value
95+
96+
97+
def test_shallow_copy_fallback():
    """A dict holding a ``NonDeepcopyable`` yields a new dict with an equal-valued item."""
    source = {"origin": NonDeepcopyable(10)}
    duplicate = safe_deepcopy(source)
    assert duplicate is not source
    assert duplicate["origin"].value == source["origin"].value
102+
103+
104+
def test_circular_reference():
    """A list that contains itself is copied into a list that contains itself."""
    looped = []
    looped.append(looped)
    duplicate = safe_deepcopy(looped)
    assert duplicate is not looped
    assert duplicate[0] is duplicate
110+
111+
112+
def test_memoization():
    """The copy is recorded in the caller-supplied memo under ``id(original)``."""
    source = dict(a=1, b=2)
    seen = {}
    duplicate = safe_deepcopy(source, seen)
    assert duplicate is seen[id(source)]
117+
118+
119+
def test_deepcopy_object_without_dict():
    """Containers of ``WithoutDict`` items are copied, but the items are shared.

    ``WithoutDict`` raises in ``__deepcopy__`` and returns itself from
    ``__copy__``, so every container copy must hold the very same item.
    """
    # dict container
    source = {"origin": WithoutDict(10)}
    duplicate = safe_deepcopy(source)
    assert duplicate["origin"].value == source["origin"].value
    assert duplicate is not source
    assert duplicate["origin"] is source["origin"]
    # Ensure __dict__ is not present
    assert not hasattr(duplicate["origin"], "__dict__")

    # list container
    source = [WithoutDict(10)]
    duplicate = safe_deepcopy(source)
    assert duplicate[0].value == source[0].value
    assert duplicate is not source
    assert duplicate[0] is source[0]

    # tuple container
    source = (WithoutDict(10),)
    duplicate = safe_deepcopy(source)
    assert duplicate[0].value == source[0].value
    assert duplicate is not source
    assert duplicate[0] is source[0]

    # set container
    item = WithoutDict(10)
    source = {item}
    duplicate = safe_deepcopy(source)
    assert duplicate is not source
    copied_item = duplicate.pop()
    assert copied_item.value == item.value
    assert copied_item is item

    # frozenset container
    item = WithoutDict(10)
    source = frozenset((item,))
    duplicate = safe_deepcopy(source)
    assert duplicate is not source
    copied_item = next(iter(duplicate))
    assert copied_item.value == item.value
    assert copied_item is item
156+
157+
def test_memo():
    """An entry already present in the memo is returned as-is."""
    payload = NormalObject(10)
    source = {"origin": payload}
    # Pre-seed the memo so the lookup for ``source`` hits immediately.
    seen = {id(source): payload}
    duplicate = safe_deepcopy(source, seen)
    assert duplicate is seen[id(source)]
163+
164+
def test_unhandled_type():
    """An item that rejects deep and shallow copying is returned unchanged.

    ``NonCopyableObject`` raises ``TypeError`` from both ``__deepcopy__`` and
    ``__copy__``, so the copied dict must hold the original item itself.
    """
    original = {"origin": NonCopyableObject(10)}
    copy_obj = safe_deepcopy(original)
    assert copy_obj["origin"].value == original["origin"].value
    assert copy_obj is not original
    assert copy_obj["origin"] is original["origin"]
    # Ensure __dict__ is not present on the item. The previous check looked
    # at the outer dict, which never has a __dict__, so it could not fail.
    assert hasattr(copy_obj["origin"], "__dict__") is False

0 commit comments

Comments
 (0)