Skip to content

Commit 553527a

Browse files
committed
fix: fix pydantic object copy
1 parent 71b22d4 commit 553527a

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

scrapegraphai/utils/copy.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def safe_deepcopy(obj: Any) -> Any:
1010
"""
1111
Attempts to create a deep copy of the object using `copy.deepcopy`
1212
whenever possible. If that fails, it falls back to custom deep copy
13-
logic or returns the original object.
14-
13+
logic. If that also fails, it raises a `DeepCopyError`.
14+
1515
Args:
1616
obj (Any): The object to be copied, which can be of any type.
1717
@@ -26,13 +26,7 @@ def safe_deepcopy(obj: Any) -> Any:
2626
try:
2727

2828
# Try to use copy.deepcopy first
29-
if isinstance(obj,BaseModel):
30-
# handle BaseModel because __fields_set__ need compatibility
31-
copied_obj = obj.copy(deep=True)
32-
else:
33-
copied_obj = copy.deepcopy(obj)
34-
35-
return copied_obj
29+
return copy.deepcopy(obj)
3630
except (TypeError, AttributeError) as e:
3731
# If deepcopy fails, handle specific types manually
3832

@@ -65,14 +59,17 @@ def safe_deepcopy(obj: Any) -> Any:
6559

6660
# Handle objects with attributes
6761
elif hasattr(obj, "__dict__"):
68-
new_obj = obj.__new__(obj.__class__)
69-
for attr in obj.__dict__:
70-
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr)))
71-
72-
return new_obj
73-
62+
# If the object cannot be deep-copied, its sub-properties are not
63+
# analyzed; a shallow copy of the object is used directly instead.
64+
try:
65+
return copy.copy(obj)
66+
except (TypeError, AttributeError):
67+
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
68+
69+
7470
# Attempt shallow copy as a fallback
7571
try:
7672
return copy.copy(obj)
7773
except (TypeError, AttributeError):
7874
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
75+

tests/utils/copy_utils_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,10 @@
44
# The safe_deepcopy function under test (and DeepCopyError) is imported below
55
from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy
66
from pydantic.v1 import BaseModel
7-
from pydantic import BaseModel as BaseModelV2
87

98
class PydantObject(BaseModel):
109
value: int
1110

12-
class PydantObjectV2(BaseModelV2):
13-
value: int
14-
1511
class NormalObject:
1612
def __init__(self, value):
1713
self.value = value
@@ -162,16 +158,16 @@ def test_client():
162158
llm_instance_config = {
163159
"model": "moonshot-v1-8k",
164160
"base_url": "https://api.moonshot.cn/v1",
165-
"moonshot_api_key": "<redacted>",
161+
"moonshot_api_key": "xxx",
166162
}
167163

168164
from langchain_community.chat_models.moonshot import MoonshotChat
169165

170166
llm_model_instance = MoonshotChat(**llm_instance_config)
171-
172167
copy_obj = safe_deepcopy(llm_model_instance)
168+
173169
assert copy_obj
174-
170+
assert hasattr(copy_obj, 'callbacks')
175171

176172
def test_circular_reference_in_dict():
177173
original = {}
@@ -182,3 +178,9 @@ def test_circular_reference_in_dict():
182178
assert copy_obj is not original
183179
# Check that the circular reference is maintained in the copy
184180
assert copy_obj['self'] is copy_obj
181+
182+
def test_with_pydantic():
183+
original = PydantObject(value=1)
184+
copy_obj = safe_deepcopy(original)
185+
assert copy_obj.value == original.value
186+
assert copy_obj is not original

0 commit comments

Comments
 (0)