
Commit cb427d1

navaj0 authored and Zhankuil committed
Version the serialization scheme (aws#891)
Co-authored-by: Zhankui Lu <[email protected]>
1 parent 4316450 commit cb427d1

File tree

9 files changed (+445, -183 lines changed)


src/sagemaker/remote_function/client.py

Lines changed: 10 additions & 9 deletions
@@ -28,6 +28,7 @@
 
 import sagemaker.remote_function.core.serialization as serialization
 from sagemaker.remote_function.errors import RemoteFunctionError, ServiceError, DeserializationError
+from sagemaker.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER
 from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
     RuntimeEnvironmentError,
 )
@@ -162,10 +163,10 @@ def wrapper(*args, **kwargs):
         except UnexpectedStatusException as usex:
             if usex.actual_status == "Failed":
                 try:
-                    exception = serialization.deserialize_obj_from_s3(
+                    exception = serialization.deserialize_exception_from_s3(
                         sagemaker_session=job_settings.sagemaker_session,
                         s3_uri=s3_path_join(
-                            job_settings.s3_root_uri, job.job_name, "exception.pkl"
+                            job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
                         ),
                     )
                 except ServiceError as serr:
@@ -202,7 +203,7 @@ def wrapper(*args, **kwargs):
         if job.describe()["TrainingJobStatus"] == "Completed":
             return serialization.deserialize_obj_from_s3(
                 sagemaker_session=job_settings.sagemaker_session,
-                s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, "results.pkl"),
+                s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
             )
 
         if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -596,7 +597,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
             try:
                 job_return = serialization.deserialize_obj_from_s3(
                     sagemaker_session=sagemaker_session,
-                    s3_uri=s3_path_join(job.s3_uri, "results.pkl"),
+                    s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
                 )
             except DeserializationError as e:
                 client_exception = e
@@ -605,9 +606,9 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
         elif describe_training_job_response["TrainingJobStatus"] == "Failed":
             state = _FINISHED
             try:
-                job_exception = serialization.deserialize_obj_from_s3(
+                job_exception = serialization.deserialize_exception_from_s3(
                     sagemaker_session=sagemaker_session,
-                    s3_uri=s3_path_join(job.s3_uri, "exception.pkl"),
+                    s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
                 )
             except ServiceError as serr:
                 chained_e = serr.__cause__
@@ -694,15 +695,15 @@ def result(self, timeout: float = None) -> Any:
         if self._job.describe()["TrainingJobStatus"] == "Completed":
             self._return = serialization.deserialize_obj_from_s3(
                 sagemaker_session=self._job.sagemaker_session,
-                s3_uri=s3_path_join(self._job.s3_uri, "results.pkl"),
+                s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
             )
             self._state = _FINISHED
             return self._return
         if self._job.describe()["TrainingJobStatus"] == "Failed":
             try:
-                self._exception = serialization.deserialize_obj_from_s3(
+                self._exception = serialization.deserialize_exception_from_s3(
                     sagemaker_session=self._job.sagemaker_session,
-                    s3_uri=s3_path_join(self._job.s3_uri, "exception.pkl"),
+                    s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
                 )
             except ServiceError as serr:
                 chained_e = serr.__cause__

src/sagemaker/remote_function/core/serialization.py

Lines changed: 166 additions & 44 deletions
@@ -13,17 +13,114 @@
 """SageMaker remote function data serializer/deserializer."""
 from __future__ import absolute_import
 
+import dataclasses
+import json
+import os
+import sys
+
 import cloudpickle
 
 from typing import Any, Callable
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
+from tblib import pickling_support
+
+
+def _get_python_version():
+    return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+
+
+@dataclasses.dataclass
+class _MetaData:
+    """Metadata about the serialized data or functions."""
 
+    version: str = "2023-04-24"
+    python_version: str = _get_python_version()
+    serialization_module: str = "cloudpickle"
 
-# TODO: 1) use dask serializer instead of cloudpickle for data serialization.
-# 2) set the pickle protocol properly
-# 3) serialization/deserialization scheme needs to be explicitly versioned
-# 4) handle exceptions
+    def to_json(self):
+        return json.dumps(dataclasses.asdict(self)).encode()
+
+    @staticmethod
+    def from_json(s):
+        try:
+            obj = json.loads(s)
+        except json.decoder.JSONDecodeError:
+            raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
+
+        metadata = _MetaData()
+        metadata.version = obj.get("version")
+        metadata.python_version = obj.get("python_version")
+        metadata.serialization_module = obj.get("serialization_module")
+
+        if not (
+            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
+        ):
+            raise DeserializationError(
+                f"Corrupt metadata file. Serialization approach {s} is not supported."
+            )
+
+        return metadata
+
+
+class CloudpickleSerializer:
+    """Serializer using cloudpickle."""
+
+    @staticmethod
+    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+        """Serializes data object and uploads it to S3.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
+                calls are delegated to.
+            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+            obj: object to be serialized and persisted
+        Raises:
+            SerializationError: when fail to serialize object to bytes.
+        """
+        try:
+            bytes_to_upload = cloudpickle.dumps(obj)
+        except Exception as e:
+            if isinstance(
+                e, NotImplementedError
+            ) and "Instance of Run type is not allowed to be pickled." in str(e):
+                raise SerializationError(
+                    """You are trying to reference to a sagemaker.experiments.run.Run instance from within the function
+                    or passing it as a function argument.
+                    Instantiate a Run in the function or use load_run instead."""
+                ) from e
+
+            raise SerializationError(
+                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
+            ) from e
+
+        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+
+    @staticmethod
+    def deserialize(sagemaker_session, s3_uri) -> Any:
+        """Downloads from S3 and then deserializes data objects.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
+                calls are delegated to.
+            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        Returns :
+            List of deserialized python objects.
+        Raises:
+            DeserializationError: when fail to serialize object to bytes.
+        """
+        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+
+        try:
+            return cloudpickle.loads(bytes_to_deserialize)
+        except Exception as e:
+            raise DeserializationError(
+                "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
+            ) from e
+
+
+# TODO: use dask serializer in case dask distributed is installed in users' environment.
 def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     """Serializes function and uploads it to S3.
 
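The _MetaData dataclass above is the heart of the versioning: every serialized artifact now travels with a small JSON descriptor recording the scheme version, the Python version, and the serialization module, and from_json rejects anything it does not recognize. A sketch of the round trip; _MetaData is a private name, so importing it directly is for illustration only, and the exact python_version bytes depend on the interpreter running the sketch:

    from sagemaker.remote_function.core.serialization import _MetaData

    blob = _MetaData().to_json()
    # e.g. b'{"version": "2023-04-24", "python_version": "3.10.8", "serialization_module": "cloudpickle"}'

    meta = _MetaData.from_json(blob)  # validates version and serialization_module
    assert meta.version == "2023-04-24"

    _MetaData.from_json(b'{"version": "1999-01-01"}')  # raises DeserializationError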
@@ -36,16 +133,13 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-    try:
-        bytes_to_upload = cloudpickle.dumps(func)
-    except Exception as e:
-        raise SerializationError(
-            "Error when serializing function [{}]: {}".format(
-                getattr(func, "__name__", repr(func)), repr(e)
-            )
-        ) from e
 
-    _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
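With this change, serialize_func_to_s3 writes two objects under the given prefix instead of one flat pickle, and deserialize_func_from_s3 (next hunk) validates the metadata before unpickling:

    <s3_uri>/metadata.json   # version, python_version, serialization_module
    <s3_uri>/payload.pkl     # the cloudpickled function

A round-trip sketch, assuming an existing sagemaker_session and a writable s3_uri prefix (both placeholders):

    import sagemaker.remote_function.core.serialization as serialization

    def double(x):
        return x * 2

    serialization.serialize_func_to_s3(double, sagemaker_session, s3_uri)
    restored = serialization.deserialize_func_from_s3(sagemaker_session, s3_uri)
    assert restored(21) == 42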
@@ -63,16 +157,11 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
     Raises:
         DeserializationError: when fail to serialize function to bytes.
     """
-    bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
 
-    try:
-        return cloudpickle.loads(bytes_to_deserialize)
-    except Exception as e:
-        raise DeserializationError(
-            "Error when deserializing bytes downloaded from {} to function: {}".format(
-                s3_uri, repr(e)
-            )
-        ) from e
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
 
 def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
@@ -87,21 +176,13 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
     Raises:
         SerializationError: when fail to serialize object to bytes.
     """
-    try:
-        bytes_to_upload = cloudpickle.dumps(obj)
-    except Exception as e:
-        if isinstance(
-            e, NotImplementedError
-        ) and "Instance of Run type is not allowed to be pickled." in str(e):
-            raise SerializationError(
-                "Remote function does not allow parameters of Run type."
-            ) from e
-
-        raise SerializationError(
-            "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
-        ) from e
 
-    _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -112,18 +193,59 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
     Returns :
-        List of deserialized python objects.
+        Deserialized python objects.
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
-    bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
 
-    try:
-        return cloudpickle.loads(bytes_to_deserialize)
-    except Exception as e:
-        raise DeserializationError(
-            "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
-        ) from e
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+
+
+def serialize_exception_to_s3(
+    exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
+):
+    """Serializes exception with traceback and uploads it to S3.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
+            calls are delegated to.
+        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        exc: Exception to be serialized and persisted
+    Raises:
+        SerializationError: when fail to serialize object to bytes.
+    """
+    pickling_support.install()
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
+
+
+def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
+    """Downloads from S3 and then deserializes exception.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
+            calls are delegated to.
+        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+    Returns :
+        Deserialized exception with traceback.
+    Raises:
+        DeserializationError: when fail to serialize object to bytes.
+    """
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
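The new exception helpers pair the same metadata.json + payload.pkl layout with tblib's pickling_support, so the traceback survives the pickle round trip. A hedged sketch, with sagemaker_session and s3_uri again as placeholders:

    import sagemaker.remote_function.core.serialization as serialization

    try:
        1 / 0
    except ZeroDivisionError as exc:
        serialization.serialize_exception_to_s3(exc, sagemaker_session, s3_uri)

    restored = serialization.deserialize_exception_from_s3(sagemaker_session, s3_uri)
    # Re-raise on the client with the original remote traceback attached.
    raise restored.with_traceback(restored.__traceback__)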
