Skip to content

Commit 5718b88

Browse files
Ao GuoNamrata Madan
authored andcommitted
Enabled integ test case for incompatible deps, polished some exception messages
1 parent c147d74 commit 5718b88

File tree

6 files changed

+48
-27
lines changed

6 files changed

+48
-27
lines changed

src/sagemaker/remote_function/core/serialization.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=N
4141
except Exception as e:
4242
raise SerializationError(
4343
"Error when serializing function [{}]: {}".format(
44-
getattr(func, "__name__", repr(func)), e
44+
getattr(func, "__name__", repr(func)), repr(e)
4545
)
4646
) from e
4747

@@ -69,7 +69,9 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
6969
return cloudpickle.loads(bytes_to_deserialize)
7070
except Exception as e:
7171
raise DeserializationError(
72-
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, e)
72+
"Error when deserializing bytes downloaded from {} to function: {}".format(
73+
s3_uri, repr(e)
74+
)
7375
) from e
7476

7577

@@ -89,7 +91,7 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st
8991
bytes_to_upload = cloudpickle.dumps(obj)
9092
except Exception as e:
9193
raise SerializationError(
92-
"Error when serializing object of type [{}]: {}".format(type(obj).__name__, e)
94+
"Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
9395
) from e
9496

9597
_upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
@@ -113,7 +115,7 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
113115
return cloudpickle.loads(bytes_to_deserialize)
114116
except Exception as e:
115117
raise DeserializationError(
116-
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, e)
118+
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
117119
) from e
118120

119121

@@ -124,12 +126,16 @@ def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
124126
bytes, s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session
125127
)
126128
except Exception as e:
127-
raise ServiceError("Failed to upload serialized bytes to {}: {}".format(s3_uri, e)) from e
129+
raise ServiceError(
130+
"Failed to upload serialized bytes to {}: {}".format(s3_uri, repr(e))
131+
) from e
128132

129133

130134
def _read_bytes_from_s3(s3_uri, sagemaker_session):
131135
"""Wrapping s3 downloading with exception translation for remote function."""
132136
try:
133137
return S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session)
134138
except Exception as e:
135-
raise ServiceError("Failed to read serialized bytes from {}: {}".format(s3_uri, e)) from e
139+
raise ServiceError(
140+
"Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e))
141+
) from e
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pandas==1.3.4

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
import pytest
1717
import os
18+
import pandas as pd
1819

1920
from sagemaker.remote_function import remote
20-
from sagemaker.remote_function.errors import RuntimeEnvironmentError
21+
from sagemaker.remote_function.errors import DeserializationError, RuntimeEnvironmentError
2122

2223
from tests.integ.kms_utils import get_or_create_kms_key
2324
from tests.integ import DATA_DIR
@@ -136,20 +137,32 @@ def divide(x, y):
136137
divide(10, 2)
137138

138139

139-
@pytest.mark.skip
140140
def test_with_incompatible_dependencies(
141141
sagemaker_session, dummy_container_without_error, cpu_instance_type
142142
):
143+
144+
dependencies_path = os.path.join(DATA_DIR, "remote_function/old_deps_requirements.txt")
145+
143146
@remote(
144147
role=ROLE,
145148
image_uri=dummy_container_without_error,
146-
dependencies="./requirements.txt",
149+
dependencies=dependencies_path,
147150
instance_type=cpu_instance_type,
148151
sagemaker_session=sagemaker_session,
149152
)
150-
def divide(x, y):
151-
return x / y
153+
def mul_ten(df: pd.DataFrame):
154+
print("hehehe")
155+
print(pd.__version__)
156+
return df.mul(10)
157+
158+
df1 = pd.DataFrame(
159+
{
160+
"A": [14, 4, 5, 4, 1],
161+
"B": [5, 2, 54, 3, 2],
162+
"C": [20, 20, 7, 3, 8],
163+
"D": [14, 3, 6, 2, 6],
164+
}
165+
)
152166

153-
# TODO: this should raise DeserializationError
154-
with pytest.raises(RuntimeError):
155-
divide(10, 2)
167+
with pytest.raises(DeserializationError):
168+
mul_ten(df1)

tests/integ/sagemaker/remote_function/test_executor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
1615
from sagemaker.remote_function import RemoteExecutor
1716

1817
ROLE = "SageMakerRole"

tests/unit/sagemaker/remote_function/core/test_serialization.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def square(x):
8484

8585
with pytest.raises(
8686
SerializationError,
87-
match=r"Error when serializing function \[square\]: some failure when dumps",
87+
match=r"Error when serializing function \[square\]: RuntimeError\('some failure when dumps'\)",
8888
):
8989
serialize_func_to_s3(
9090
func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY
@@ -95,7 +95,8 @@ def square(x):
9595

9696
with pytest.raises(
9797
DeserializationError,
98-
match=rf"Error when deserializing bytes downloaded from {s3_uri}: some failure when loads",
98+
match=rf"Error when deserializing bytes downloaded from {s3_uri} to function: "
99+
+ r"RuntimeError\('some failure when loads'\)",
99100
):
100101
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
101102
mock_cloudpickle_loads.assert_called_with({})
@@ -116,7 +117,7 @@ def test_serialize_deserialize_lambda_func_serialization_error(
116117

117118
with pytest.raises(
118119
SerializationError,
119-
match=r"Error when serializing function \[<lambda>\]: some failure when dumps",
120+
match=r"Error when serializing function \[<lambda>\]: RuntimeError\('some failure when dumps'\)",
120121
):
121122
serialize_func_to_s3(
122123
func=my_func, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY
@@ -127,7 +128,8 @@ def test_serialize_deserialize_lambda_func_serialization_error(
127128

128129
with pytest.raises(
129130
DeserializationError,
130-
match=rf"Error when deserializing bytes downloaded from {s3_uri}: some failure when loads",
131+
match=rf"Error when deserializing bytes downloaded from {s3_uri} to function: "
132+
+ r"RuntimeError\('some failure when loads'\)",
131133
):
132134
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
133135
mock_cloudpickle_loads.assert_called_with({})
@@ -199,7 +201,7 @@ def __init__(self, x):
199201

200202
with pytest.raises(
201203
SerializationError,
202-
match=r"Error when serializing object of type \[MyData\]: some failure when dumps",
204+
match=r"Error when serializing object of type \[MyData\]: RuntimeError\('some failure when dumps'\)",
203205
):
204206
serialize_obj_to_s3(
205207
obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY
@@ -211,7 +213,7 @@ def __init__(self, x):
211213

212214
with pytest.raises(
213215
DeserializationError,
214-
match=rf"Error when deserializing bytes downloaded from {s3_uri}: some failure when loads",
216+
match=rf"Error when deserializing bytes downloaded from {s3_uri}: RuntimeError\('some failure when loads'\)",
215217
):
216218
deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
217219
mock_cloudpickle_loads.assert_called_with({})
@@ -226,7 +228,7 @@ def test_serialize_deserialize_service_error():
226228
s3_uri = random_s3_uri()
227229
with pytest.raises(
228230
ServiceError,
229-
match=rf"Failed to upload serialized bytes to {s3_uri}: some failure when upload_bytes",
231+
match=rf"Failed to upload serialized bytes to {s3_uri}: RuntimeError\('some failure when upload_bytes'\)",
230232
):
231233
serialize_func_to_s3(
232234
func=my_func, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY
@@ -236,6 +238,6 @@ def test_serialize_deserialize_service_error():
236238

237239
with pytest.raises(
238240
ServiceError,
239-
match=rf"Failed to read serialized bytes from {s3_uri}: some failure when read_bytes",
241+
match=rf"Failed to read serialized bytes from {s3_uri}: RuntimeError\('some failure when read_bytes'\)",
240242
):
241243
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def square(x):
193193

194194
with pytest.raises(
195195
ServiceError,
196-
match=r"Failed to read serialized bytes from .+: some error when reading from s3",
196+
match=r"Failed to read serialized bytes from .+: RuntimeError\('some error when reading from s3'\)",
197197
):
198198
square(5)
199199

@@ -224,7 +224,7 @@ def square(x):
224224
with pytest.raises(
225225
DeserializationError,
226226
match=r"Error when deserializing bytes downloaded from .+: "
227-
"some value error when deserializing",
227+
r"ValueError\('some value error when deserializing'\)",
228228
):
229229
square(5)
230230
assert MockJobSettings.call_args.kwargs["image_uri"] == IMAGE
@@ -738,7 +738,7 @@ def test_future_get_result_from_failed_job_local_error_service_error(mock_start,
738738

739739
with pytest.raises(
740740
ServiceError,
741-
match=r"Failed to read serialized bytes from .+: some error when reading from s3",
741+
match=r"Failed to read serialized bytes from .+: RuntimeError\('some error when reading from s3'\)",
742742
):
743743
future.result()
744744

@@ -770,7 +770,7 @@ def test_future_get_result_from_failed_job_local_error_remote_function_error(
770770
with pytest.raises(
771771
DeserializationError,
772772
match=r"Error when deserializing bytes downloaded from .+: "
773-
"some value error when deserializing",
773+
r"ValueError\('some value error when deserializing'\)",
774774
):
775775
future.result()
776776

0 commit comments

Comments
 (0)