"""SageMaker remote function data serializer/deserializer."""
from __future__ import absolute_import

+import dataclasses
+import json
+import os
+import sys
+
import cloudpickle

from typing import Any, Callable
from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
from sagemaker.s3 import S3Downloader, S3Uploader
+from tblib import pickling_support
+
+
+def _get_python_version():
+    return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+
+
+@dataclasses.dataclass
+class _MetaData:
+    """Metadata about the serialized data or functions."""

+    version: str = "2023-04-24"
+    python_version: str = _get_python_version()
+    serialization_module: str = "cloudpickle"

-# TODO: 1) use dask serializer instead of cloudpickle for data serialization.
-# 2) set the pickle protocol properly
-# 3) serialization/deserialization scheme needs to be explicitly versioned
-# 4) handle exceptions
+    def to_json(self):
+        return json.dumps(dataclasses.asdict(self)).encode()
+
+    @staticmethod
+    def from_json(s):
+        try:
+            obj = json.loads(s)
+        except json.decoder.JSONDecodeError:
+            raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
+
+        metadata = _MetaData()
+        metadata.version = obj.get("version")
+        metadata.python_version = obj.get("python_version")
+        metadata.serialization_module = obj.get("serialization_module")
+
+        if not (
+            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
+        ):
+            raise DeserializationError(
+                f"Corrupt metadata file. Serialization approach {s} is not supported."
+            )
+
+        return metadata
+
+
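A minimal sketch of the round trip this class defines, assuming the module above is importable; the python_version value shown is only an example and depends on the running interpreter:

# Sketch only: round-trip the metadata record defined above.
meta_bytes = _MetaData().to_json()
# e.g. b'{"version": "2023-04-24", "python_version": "3.10.8", "serialization_module": "cloudpickle"}'
restored = _MetaData.from_json(meta_bytes)
assert restored.version == "2023-04-24"
# Metadata written by any other scheme fails closed with DeserializationError:
# _MetaData.from_json(b'{"version": "2000-01-01", "serialization_module": "pickle"}')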
+class CloudpickleSerializer:
+    """Serializer using cloudpickle."""
+
+    @staticmethod
+    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+        """Serializes data object and uploads it to S3.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
+                calls are delegated to.
+            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+            obj: Object to be serialized and persisted.
+        Raises:
+            SerializationError: when failing to serialize the object to bytes.
+        """
+        try:
+            bytes_to_upload = cloudpickle.dumps(obj)
+        except Exception as e:
+            if isinstance(
+                e, NotImplementedError
+            ) and "Instance of Run type is not allowed to be pickled." in str(e):
+                raise SerializationError(
+                    """You are trying to reference a sagemaker.experiments.run.Run instance from within the function
+                    or to pass it as a function argument.
+                    Instantiate a Run in the function or use load_run instead."""
+                ) from e
+
+            raise SerializationError(
+                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
+            ) from e
+
+        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+
+    @staticmethod
+    def deserialize(sagemaker_session, s3_uri) -> Any:
+        """Downloads from S3 and then deserializes data objects.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
+                calls are delegated to.
+            s3_uri (str): S3 root uri from which the serialized artifacts are downloaded.
+        Returns:
+            Deserialized python object.
+        Raises:
+            DeserializationError: when failing to deserialize the downloaded bytes.
+        """
+        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+
+        try:
+            return cloudpickle.loads(bytes_to_deserialize)
+        except Exception as e:
+            raise DeserializationError(
+                "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
+            ) from e
+
+
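A hedged usage sketch of the serializer pair above; the bucket name is a placeholder and a default sagemaker.session.Session is assumed:

# Sketch only: "s3://my-bucket/..." is a placeholder URI.
import sagemaker

session = sagemaker.Session()
CloudpickleSerializer.serialize({"lr": 0.1}, session, "s3://my-bucket/demo/payload.pkl")
assert CloudpickleSerializer.deserialize(session, "s3://my-bucket/demo/payload.pkl") == {"lr": 0.1}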
+# TODO: use dask serializer in case dask distributed is installed in users' environment.
def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
    """Serializes function and uploads it to S3.
@@ -36,16 +133,13 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=N
    Raises:
        SerializationError: when fail to serialize function to bytes.
    """
-    try:
-        bytes_to_upload = cloudpickle.dumps(func)
-    except Exception as e:
-        raise SerializationError(
-            "Error when serializing function [{}]: {}".format(
-                getattr(func, "__name__", repr(func)), repr(e)
-            )
-        ) from e

-    _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -63,16 +157,11 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
    Raises:
        DeserializationError: when fail to serialize function to bytes.
    """
-    bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )

-    try:
-        return cloudpickle.loads(bytes_to_deserialize)
-    except Exception as e:
-        raise DeserializationError(
-            "Error when deserializing bytes downloaded from {} to function: {}".format(
-                s3_uri, repr(e)
-            )
-        ) from e
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
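With the metadata split out, each serialized artifact now occupies two keys under the S3 root rather than a single pickle. A sketch, with a placeholder bucket and session:

# Sketch only: serialize_func_to_s3 writes two objects under the root uri.
import sagemaker

session = sagemaker.Session()

def double(x):
    return x * 2

serialize_func_to_s3(double, session, "s3://my-bucket/job/function")
# Writes:
#   s3://my-bucket/job/function/metadata.json  <- versioned _MetaData record
#   s3://my-bucket/job/function/payload.pkl    <- cloudpickled function
restored = deserialize_func_from_s3(session, "s3://my-bucket/job/function")
assert restored(21) == 42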


def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
@@ -87,21 +176,13 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st
    Raises:
        SerializationError: when fail to serialize object to bytes.
    """
-    try:
-        bytes_to_upload = cloudpickle.dumps(obj)
-    except Exception as e:
-        if isinstance(
-            e, NotImplementedError
-        ) and "Instance of Run type is not allowed to be pickled." in str(e):
-            raise SerializationError(
-                "Remote function does not allow parameters of Run type."
-            ) from e
-
-        raise SerializationError(
-            "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
-        ) from e

-    _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -112,18 +193,59 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
            calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    Returns:
-        List of deserialized python objects.
+        Deserialized python object.
    Raises:
        DeserializationError: when fail to serialize object to bytes.
    """
-    bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)

-    try:
-        return cloudpickle.loads(bytes_to_deserialize)
-    except Exception as e:
-        raise DeserializationError(
-            "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
-        ) from e
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+
+
+def serialize_exception_to_s3(
+    exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
+):
+    """Serializes exception with traceback and uploads it to S3.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
+            calls are delegated to.
+        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        exc: Exception to be serialized and persisted.
+    Raises:
+        SerializationError: when failing to serialize the exception to bytes.
+    """
+    pickling_support.install()
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
+
+
+def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
+    """Downloads from S3 and then deserializes exception.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
+            calls are delegated to.
+        s3_uri (str): S3 root uri from which the serialized artifacts are downloaded.
+    Returns:
+        Deserialized exception with traceback.
+    Raises:
+        DeserializationError: when failing to deserialize the downloaded bytes.
+    """
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
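Because pickling_support.install() registers tblib's reducers before pickling, the deserialized exception still carries its original traceback and can be re-raised on the receiving side. A sketch with a placeholder uri and session:

# Sketch only: round-trip an exception, then re-raise it with its traceback.
import sagemaker

session = sagemaker.Session()
try:
    1 / 0
except ZeroDivisionError as exc:
    serialize_exception_to_s3(exc, session, "s3://my-bucket/job/exception")

remote_exc = deserialize_exception_from_s3(session, "s3://my-bucket/job/exception")
raise remote_exc  # traceback still points at the original "1 / 0" frame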


def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):