 import json
 import os
 import sys
-import pickle
-from enum import Enum

 import cloudpickle

 from typing import Any, Callable
-
-from sagemaker.s3 import s3_path_join
-
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
 from tblib import pickling_support

-METADATA_FILE = "metadata.json"
-PAYLOAD_FILE = "payload.pkl"
-HEADER_FILE = "headers.pkl"
-FRAME_FILE = "frame-{}.dat"
-

 def _get_python_version():
     return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"


-class SerializationModule(str, Enum):
-    """Represents various serialization modules used."""
-
-    CLOUDPICKLE = "cloudpickle"
-    DASK = "dask"
-
-
 @dataclasses.dataclass
 class _MetaData:
     """Metadata about the serialized data or functions."""

-    serialization_module: SerializationModule
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
+    serialization_module: str = "cloudpickle"

     def to_json(self):
         return json.dumps(dataclasses.asdict(self)).encode()
@@ -62,13 +45,16 @@ def to_json(self):
     def from_json(s):
         try:
             obj = json.loads(s)
-            metadata = _MetaData(**obj)
-        except (json.decoder.JSONDecodeError, TypeError):
+        except json.decoder.JSONDecodeError:
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")

-        if (
-            metadata.version != "2023-04-24"
-            or metadata.serialization_module not in SerializationModule.__members__.values()
+        metadata = _MetaData()
+        metadata.version = obj.get("version")
+        metadata.python_version = obj.get("python_version")
+        metadata.serialization_module = obj.get("serialization_module")
+
+        if not (
+            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
         ):
             raise DeserializationError(
                 f"Corrupt metadata file. Serialization approach {s} is not supported."
@@ -93,12 +79,6 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
-        _upload_bytes_to_s3(
-            _MetaData(SerializationModule.CLOUDPICKLE).to_json(),
-            os.path.join(s3_uri, METADATA_FILE),
-            s3_kms_key,
-            sagemaker_session,
-        )
         try:
             bytes_to_upload = cloudpickle.dumps(obj)
         except Exception as e:
@@ -116,76 +96,7 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                 "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
             ) from e

-        _upload_bytes_to_s3(
-            bytes_to_upload, os.path.join(s3_uri, PAYLOAD_FILE), s3_kms_key, sagemaker_session
-        )
-
-    @staticmethod
-    def deserialize(sagemaker_session, s3_uri) -> Any:
-        """Downloads from S3 and then deserializes data objects.
-
-        Args:
-            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which
-                AWS service calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-        Returns:
-            List of deserialized python objects.
-        Raises:
-            DeserializationError: when fail to serialize object to bytes.
-        """
-        bytes_to_deserialize = _read_bytes_from_s3(
-            os.path.join(s3_uri, PAYLOAD_FILE), sagemaker_session
-        )
-
-        try:
-            return cloudpickle.loads(bytes_to_deserialize)
-        except Exception as e:
-            raise DeserializationError(
-                "Error when deserializing bytes downloaded from {}: {}".format(
-                    os.path.join(s3_uri, PAYLOAD_FILE), repr(e)
-                )
-            ) from e
-
-
-class DaskSerializer:
-    """Serializer using Dask."""
-
-    @staticmethod
-    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
-        """Serializes data object and uploads it to S3.
-
-        Args:
-            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS
-                service calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
-            obj: object to be serialized and persisted
-        Raises:
-            SerializationError: when fail to serialize object to bytes.
-        """
-        import distributed.protocol as dask
-
-        _upload_bytes_to_s3(
-            _MetaData(SerializationModule.DASK).to_json(),
-            os.path.join(s3_uri, METADATA_FILE),
-            s3_kms_key,
-            sagemaker_session,
-        )
-        try:
-            header, frames = dask.serialize(obj, on_error="raise")
-        except Exception as e:
-            raise SerializationError(
-                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
-            ) from e
-
-        _upload_bytes_to_s3(
-            pickle.dumps(header), s3_path_join(s3_uri, HEADER_FILE), s3_kms_key, sagemaker_session
-        )
-        for idx, frame in enumerate(frames):
-            frame = bytes(frame) if isinstance(frame, memoryview) else frame
-            _upload_bytes_to_s3(
-                frame, s3_path_join(s3_uri, FRAME_FILE.format(idx)), s3_kms_key, sagemaker_session
-            )
+        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)

     @staticmethod
     def deserialize(sagemaker_session, s3_uri) -> Any:
@@ -200,29 +111,19 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when fail to serialize object to bytes.
         """
-        import distributed.protocol as dask
+        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)

-        header_to_deserialize = _read_bytes_from_s3(
-            s3_path_join(s3_uri, HEADER_FILE), sagemaker_session
-        )
-        headers = pickle.loads(header_to_deserialize)
-        num_frames = len(headers["frame-lengths"]) if "frame-lengths" in headers else 1
-        frames = []
-        for idx in range(num_frames):
-            frame = _read_bytes_from_s3(
-                s3_path_join(s3_uri, FRAME_FILE.format(idx)), sagemaker_session
-            )
-            frames.append(frame)
         try:
-            return dask.deserialize(headers, frames)
+            return cloudpickle.loads(bytes_to_deserialize)
         except Exception as e:
             raise DeserializationError(
                 "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
             ) from e


+# TODO: use dask serializer in case dask distributed is installed in users' environment.
 def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
-    """Serializes function using cloudpickle and uploads it to S3.
+    """Serializes function and uploads it to S3.

     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
@@ -233,7 +134,13 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-    CloudpickleSerializer.serialize(func, sagemaker_session, s3_uri, s3_kms_key)
+
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


 def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -251,16 +158,16 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
     Raises:
         DeserializationError: when fail to serialize function to bytes.
     """
-    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))


 def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
     """Serializes data object and uploads it to S3.

-    This method uses the Dask library to perform serialization if its already installed, otherwise,
-    it uses cloudpickle.
-
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
@@ -271,12 +178,12 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         SerializationError: when fail to serialize object to bytes.
     """

-    try:
-        import distributed.protocol as dask  # noqa: F401
-
-        DaskSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
-    except ModuleNotFoundError:
-        CloudpickleSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


 def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -291,12 +198,12 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
-    metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session)
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
-    if metadata.serialization_module == SerializationModule.DASK:
-        return DaskSerializer.deserialize(sagemaker_session, s3_uri)
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))


 def serialize_exception_to_s3(
@@ -314,7 +221,12 @@ def serialize_exception_to_s3(
         SerializationError: when fail to serialize object to bytes.
     """
     pickling_support.install()
-    CloudpickleSerializer.serialize(exc, sagemaker_session, s3_uri, s3_kms_key)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


 def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -329,8 +241,12 @@ def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
-    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))


 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):