Commit 2a8f4cc

Author: Namrata Madan committed
Revert "feature: serialize objs using dask (aws#907)"
This reverts commit 27cfeff29f961fd6bb28bd91545de6ec888b284a.
1 parent bd38d6c · commit 2a8f4cc

File tree: 9 files changed, +278 −724 lines changed
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+cloudpickle==2.2.0
+tblib==1.7.0

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 2 deletions

@@ -21,5 +21,4 @@ sagemaker-experiments==0.1.35
 Jinja2==3.0.3
 pandas>=1.3.5,<1.5
 scikit-learn==1.0.2
-cloudpickle==2.2.1
-distributed==2022.2.0
+cloudpickle==2.2.0

src/sagemaker/remote_function/core/serialization.py

Lines changed: 49 additions & 133 deletions

@@ -17,43 +17,26 @@
 import json
 import os
 import sys
-import pickle
-from enum import Enum
 
 import cloudpickle
 
 from typing import Any, Callable
-
-from sagemaker.s3 import s3_path_join
-
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
 from tblib import pickling_support
 
-METADATA_FILE = "metadata.json"
-PAYLOAD_FILE = "payload.pkl"
-HEADER_FILE = "headers.pkl"
-FRAME_FILE = "frame-{}.dat"
-
 
 def _get_python_version():
     return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 
 
-class SerializationModule(str, Enum):
-    """Represents various serialization modules used."""
-
-    CLOUDPICKLE = "cloudpickle"
-    DASK = "dask"
-
-
 @dataclasses.dataclass
 class _MetaData:
     """Metadata about the serialized data or functions."""
 
-    serialization_module: SerializationModule
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
+    serialization_module: str = "cloudpickle"
 
     def to_json(self):
         return json.dumps(dataclasses.asdict(self)).encode()
@@ -62,13 +45,16 @@ def to_json(self):
     def from_json(s):
         try:
             obj = json.loads(s)
-            metadata = _MetaData(**obj)
-        except (json.decoder.JSONDecodeError, TypeError):
+        except json.decoder.JSONDecodeError:
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
 
-        if (
-            metadata.version != "2023-04-24"
-            or metadata.serialization_module not in SerializationModule.__members__.values()
+        metadata = _MetaData()
+        metadata.version = obj.get("version")
+        metadata.python_version = obj.get("python_version")
+        metadata.serialization_module = obj.get("serialization_module")
+
+        if not (
+            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
         ):
            raise DeserializationError(
                f"Corrupt metadata file. Serialization approach {s} is not supported."
@@ -93,12 +79,6 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
-        _upload_bytes_to_s3(
-            _MetaData(SerializationModule.CLOUDPICKLE).to_json(),
-            os.path.join(s3_uri, METADATA_FILE),
-            s3_kms_key,
-            sagemaker_session,
-        )
         try:
             bytes_to_upload = cloudpickle.dumps(obj)
         except Exception as e:
@@ -116,76 +96,7 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                 "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
             ) from e
 
-        _upload_bytes_to_s3(
-            bytes_to_upload, os.path.join(s3_uri, PAYLOAD_FILE), s3_kms_key, sagemaker_session
-        )
-
-    @staticmethod
-    def deserialize(sagemaker_session, s3_uri) -> Any:
-        """Downloads from S3 and then deserializes data objects.
-
-        Args:
-            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which
-                AWS service calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-        Returns :
-            List of deserialized python objects.
-        Raises:
-            DeserializationError: when fail to serialize object to bytes.
-        """
-        bytes_to_deserialize = _read_bytes_from_s3(
-            os.path.join(s3_uri, PAYLOAD_FILE), sagemaker_session
-        )
-
-        try:
-            return cloudpickle.loads(bytes_to_deserialize)
-        except Exception as e:
-            raise DeserializationError(
-                "Error when deserializing bytes downloaded from {}: {}".format(
-                    os.path.join(s3_uri, PAYLOAD_FILE), repr(e)
-                )
-            ) from e
-
-
-class DaskSerializer:
-    """Serializer using Dask."""
-
-    @staticmethod
-    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
-        """Serializes data object and uploads it to S3.
-
-        Args:
-            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS
-                service calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
-            obj: object to be serialized and persisted
-        Raises:
-            SerializationError: when fail to serialize object to bytes.
-        """
-        import distributed.protocol as dask
-
-        _upload_bytes_to_s3(
-            _MetaData(SerializationModule.DASK).to_json(),
-            os.path.join(s3_uri, METADATA_FILE),
-            s3_kms_key,
-            sagemaker_session,
-        )
-        try:
-            header, frames = dask.serialize(obj, on_error="raise")
-        except Exception as e:
-            raise SerializationError(
-                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
-            ) from e
-
-        _upload_bytes_to_s3(
-            pickle.dumps(header), s3_path_join(s3_uri, HEADER_FILE), s3_kms_key, sagemaker_session
-        )
-        for idx, frame in enumerate(frames):
-            frame = bytes(frame) if isinstance(frame, memoryview) else frame
-            _upload_bytes_to_s3(
-                frame, s3_path_join(s3_uri, FRAME_FILE.format(idx)), s3_kms_key, sagemaker_session
-            )
+        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
 
     @staticmethod
     def deserialize(sagemaker_session, s3_uri) -> Any:
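
For context on the code removed above: the deleted DaskSerializer split an object into a pickled header plus a list of binary frames and uploaded each piece as a separate S3 object. A minimal local sketch of that pattern, assuming the optional distributed package is installed; the object and variable names here are illustrative only:

import pickle
import distributed.protocol as dask

obj = {"weights": list(range(10))}

# dask.serialize returns a header dict plus zero or more binary frames
header, frames = dask.serialize(obj, on_error="raise")

# the deleted code pickled the header and converted memoryview frames to bytes before upload
header_bytes = pickle.dumps(header)
frame_bytes = [bytes(f) if isinstance(f, memoryview) else f for f in frames]

# reverse direction: rebuild the object from header + frames
assert dask.deserialize(header, frames) == obj
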
@@ -200,29 +111,19 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when fail to serialize object to bytes.
         """
-        import distributed.protocol as dask
+        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
 
-        header_to_deserialize = _read_bytes_from_s3(
-            s3_path_join(s3_uri, HEADER_FILE), sagemaker_session
-        )
-        headers = pickle.loads(header_to_deserialize)
-        num_frames = len(headers["frame-lengths"]) if "frame-lengths" in headers else 1
-        frames = []
-        for idx in range(num_frames):
-            frame = _read_bytes_from_s3(
-                s3_path_join(s3_uri, FRAME_FILE.format(idx)), sagemaker_session
-            )
-            frames.append(frame)
         try:
-            return dask.deserialize(headers, frames)
+            return cloudpickle.loads(bytes_to_deserialize)
         except Exception as e:
             raise DeserializationError(
                 "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
             ) from e
 
 
+# TODO: use dask serializer in case dask distributed is installed in users' environment.
 def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
-    """Serializes function using cloudpickle and uploads it to S3.
+    """Serializes function and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
@@ -233,7 +134,13 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-    CloudpickleSerializer.serialize(func, sagemaker_session, s3_uri, s3_kms_key)
+
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -251,16 +158,16 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
     Raises:
         DeserializationError: when fail to serialize function to bytes.
     """
-    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
 
 def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
     """Serializes data object and uploads it to S3.
 
-    This method uses the Dask library to perform serialization if its already installed, otherwise,
-    it uses cloudpickle.
-
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
@@ -271,12 +178,12 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         SerializationError: when fail to serialize object to bytes.
     """
 
-    try:
-        import distributed.protocol as dask  # noqa: F401
-
-        DaskSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
-    except ModuleNotFoundError:
-        CloudpickleSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -291,12 +198,12 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
        DeserializationError: when fail to serialize object to bytes.
     """
-    metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session)
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
-    if metadata.serialization_module == SerializationModule.DASK:
-        return DaskSerializer.deserialize(sagemaker_session, s3_uri)
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
 
 def serialize_exception_to_s3(
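
After the revert, every serialize_*_to_s3 helper writes exactly two objects under the given S3 root URI, and the matching deserialize helpers read them back in the same order. A small sketch of the resulting key layout; the bucket name and prefix below are hypothetical:

import os

s3_root = "s3://example-bucket/remote-function/run-123"  # hypothetical root URI

# serialize_obj_to_s3 uploads the metadata first, then the cloudpickled payload
metadata_uri = os.path.join(s3_root, "metadata.json")  # holds _MetaData().to_json()
payload_uri = os.path.join(s3_root, "payload.pkl")     # holds cloudpickle.dumps(obj)

print(metadata_uri)  # s3://example-bucket/remote-function/run-123/metadata.json
print(payload_uri)   # s3://example-bucket/remote-function/run-123/payload.pkl
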
@@ -314,7 +221,12 @@
         SerializationError: when fail to serialize object to bytes.
     """
     pickling_support.install()
-    CloudpickleSerializer.serialize(exc, sagemaker_session, s3_uri, s3_kms_key)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -329,8 +241,12 @@ def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
-    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
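
The local serialization step that CloudpickleSerializer wraps is a plain cloudpickle round trip. A minimal sketch without the S3 transfer (the real helpers hand the bytes to _upload_bytes_to_s3 and _read_bytes_from_s3); the function used here is illustrative only:

import cloudpickle


def make_adder(n):
    return lambda x: x + n


payload = cloudpickle.dumps(make_adder(5))  # the bytes uploaded as payload.pkl
restored = cloudpickle.loads(payload)       # the bytes read back from S3
assert restored(2) == 7
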

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
 
 import sagemaker.remote_function.core.serialization as serialization
 
+
 logger = logging_config.get_logger()
 
 
tests/data/remote_function/requirements_dask.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments