Skip to content

Commit 3bc8575

Browse files
author
Balaji Veeramani
committed
Add SerDe base classes
1 parent b786a51 commit 3bc8575

File tree

5 files changed

+254
-21
lines changed

5 files changed

+254
-21
lines changed

src/sagemaker/deserializers.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Implements methods for deserializing data returned from an inference endpoint."""
14+
from __future__ import absolute_import
15+
16+
import abc
17+
18+
19+
class BaseDeserializer(abc.ABC):
    """Abstract base class for creation of new deserializers.

    Subclasses must override the ``deserialize`` method and provide an
    ``ACCEPT`` attribute (either a plain class attribute or a property).
    A subclass that leaves either one abstract cannot be instantiated.
    """

    @abc.abstractmethod
    def deserialize(self, data, content_type):
        """Deserialize data received from an inference endpoint.

        Args:
            data (object): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            object: The data deserialized into an object.
        """
        raise NotImplementedError

    # Declared as an abstract property so that ``abc`` refuses to instantiate
    # subclasses that do not define ACCEPT.
    @property
    @abc.abstractmethod
    def ACCEPT(self):
        """The content type that is expected from the inference endpoint."""
        raise NotImplementedError

src/sagemaker/predictor.py

Lines changed: 101 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import numpy as np
2222

2323
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
24+
from sagemaker.deserializers import BaseDeserializer
2425
from sagemaker.model_monitor import DataCaptureConfig
26+
from sagemaker.serializers import BaseSerializer
2527
from sagemaker.session import production_variant, Session
2628
from sagemaker.utils import name_from_base
2729

@@ -59,27 +61,23 @@ def __init__(
5961
object, used for SageMaker interactions (default: None). If not
6062
specified, one is created using the default AWS configuration
6163
chain.
62-
serializer (callable): Accepts a single argument, the input data,
63-
and returns a sequence of bytes. It may provide a
64-
``content_type`` attribute that defines the endpoint request
65-
content type. If not specified, a sequence of bytes is expected
66-
for the data.
67-
deserializer (callable): Accepts two arguments, the result data and
68-
the response content type, and returns a sequence of bytes. It
69-
may provide a ``content_type`` attribute that defines the
70-
endpoint response's "Accept" content type. If not specified, a
71-
sequence of bytes is expected for the data.
64+
serializer (sagemaker.serializers.BaseSerializer): A serializer
65+
object, used to encode data for an inference endpoint
66+
(default: None).
67+
deserializer (sagemaker.deserializers.BaseDeserializer): A
68+
deserializer object, used to decode data from an inference
69+
endpoint (default: None).
7270
content_type (str): The invocation's "ContentType", overriding any
73-
``content_type`` from the serializer (default: None).
71+
``CONTENT_TYPE`` from the serializer (default: None).
7472
accept (str): The invocation's "Accept", overriding any accept from
7573
the deserializer (default: None).
7674
"""
7775
self.endpoint_name = endpoint_name
7876
self.sagemaker_session = sagemaker_session or Session()
7977
self.serializer = serializer
8078
self.deserializer = deserializer
81-
self.content_type = content_type or getattr(serializer, "content_type", None)
82-
self.accept = accept or getattr(deserializer, "accept", None)
79+
self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None)
80+
self.accept = accept or getattr(deserializer, "ACCEPT", None)
8381
self._endpoint_config_name = self._get_endpoint_config_name()
8482
self._model_names = self._get_model_names()
8583

@@ -121,7 +119,7 @@ def _handle_response(self, response):
121119
response_body = response["Body"]
122120
if self.deserializer is not None:
123121
# It's the deserializer's responsibility to close the stream
124-
return self.deserializer(response_body, response["ContentType"])
122+
return self.deserializer.deserialize(response_body, response["ContentType"])
125123
data = response_body.read()
126124
response_body.close()
127125
return data
@@ -152,7 +150,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
152150
args["TargetVariant"] = target_variant
153151

154152
if self.serializer is not None:
155-
data = self.serializer(data)
153+
data = self.serializer.serialize(data)
156154

157155
args["Body"] = data
158156
return args
@@ -406,6 +404,88 @@ def _get_model_names(self):
406404
return [d["ModelName"] for d in production_variants]
407405

408406

407+
class LegacySerializer(BaseSerializer):
    """Wrapper that makes legacy serializers forward compatible."""

    def __init__(self, serializer):
        """Wrap a legacy, callable-style serializer.

        Args:
            serializer (callable): A legacy serializer. It may expose a
                ``content_type`` attribute naming the request MIME type.
        """
        self.serializer = serializer
        # ``content_type`` was optional on legacy serializers, so fall back to
        # None instead of raising AttributeError when it is absent.
        self.content_type = getattr(serializer, "content_type", None)

    def __call__(self, *args, **kwargs):
        """Forward a legacy-style call to the wrapped serializer.

        Keeps the wrapper usable by code that still invokes the serializer
        as a plain callable.

        Args:
            *args: Positional arguments passed through unchanged.
            **kwargs: Keyword arguments passed through unchanged.

        Returns:
            object: Serialized data used for a request.
        """
        return self.serializer(*args, **kwargs)

    def serialize(self, data):
        """Serialize data by calling the wrapped legacy serializer.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """
        return self.serializer(data)

    @property
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.content_type
445+
446+
447+
class LegacyDeserializer(BaseDeserializer):
    """Wrapper that makes legacy deserializers forward compatible."""

    def __init__(self, deserializer):
        """Wrap a legacy, callable-style deserializer.

        Args:
            deserializer (callable): A legacy deserializer. It may expose an
                ``accept`` attribute naming the expected response MIME type.
        """
        self.deserializer = deserializer
        # ``accept`` was optional on legacy deserializers, so fall back to
        # None instead of raising AttributeError when it is absent.
        self.accept = getattr(deserializer, "accept", None)

    def __call__(self, *args, **kwargs):
        """Forward a legacy-style call to the wrapped deserializer.

        Keeps the wrapper usable by code that still invokes the deserializer
        as a plain callable.

        Args:
            *args: Positional arguments passed through unchanged.
            **kwargs: Keyword arguments passed through unchanged.

        Returns:
            object: The data deserialized into an object.
        """
        return self.deserializer(*args, **kwargs)

    def deserialize(self, data, content_type):
        """Deserialize data by calling the wrapped legacy deserializer.

        Args:
            data (object): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            object: The data deserialized into an object.
        """
        return self.deserializer(data, content_type)

    @property
    def ACCEPT(self):
        """The content type that is expected from the inference endpoint."""
        return self.accept
487+
488+
409489
class _CsvSerializer(object):
410490
"""Placeholder docstring"""
411491

@@ -479,7 +559,7 @@ def _csv_serialize_object(data):
479559
return csv_buffer.getvalue().rstrip("\r\n")
480560

481561

482-
csv_serializer = _CsvSerializer()
562+
csv_serializer = LegacySerializer(_CsvSerializer())
483563

484564

485565
def _is_mutable_sequence_like(obj):
@@ -531,7 +611,7 @@ def __call__(self, stream, content_type):
531611
stream.close()
532612

533613

534-
csv_deserializer = _CsvDeserializer()
614+
csv_deserializer = LegacyDeserializer(_CsvDeserializer())
535615

536616

537617
class BytesDeserializer(object):
@@ -644,7 +724,7 @@ def __call__(self, data):
644724
return json.dumps(_ndarray_to_list(data))
645725

646726

647-
json_serializer = _JsonSerializer()
727+
json_serializer = LegacySerializer(_JsonSerializer())
648728

649729

650730
def _ndarray_to_list(data):
@@ -686,7 +766,7 @@ def __call__(self, stream, content_type):
686766
stream.close()
687767

688768

689-
json_deserializer = _JsonDeserializer()
769+
json_deserializer = LegacyDeserializer(_JsonDeserializer())
690770

691771

692772
class _NumpyDeserializer(object):
@@ -730,7 +810,7 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
730810
)
731811

732812

733-
numpy_deserializer = _NumpyDeserializer()
813+
numpy_deserializer = LegacyDeserializer(_NumpyDeserializer())
734814

735815

736816
class _NPYSerializer(object):
@@ -778,4 +858,4 @@ def _npy_serialize(data):
778858
return buffer.getvalue()
779859

780860

781-
npy_serializer = _NPYSerializer()
861+
npy_serializer = LegacySerializer(_NPYSerializer())

src/sagemaker/serializers.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Implements methods for serializing data for an inference endpoint."""
14+
from __future__ import absolute_import
15+
16+
import abc
17+
18+
19+
class BaseSerializer(abc.ABC):
    """Abstract base class for creation of new serializers.

    Subclasses must override the ``serialize`` method and provide a
    ``CONTENT_TYPE`` attribute (either a plain class attribute or a
    property). A subclass that leaves either one abstract cannot be
    instantiated.
    """

    @abc.abstractmethod
    def serialize(self, data):
        """Serialize data into the media type specified by CONTENT_TYPE.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """
        raise NotImplementedError

    # Declared as an abstract property so that ``abc`` refuses to instantiate
    # subclasses that do not define CONTENT_TYPE.
    @property
    @abc.abstractmethod
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""
        raise NotImplementedError
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker import deserializers
18+
19+
20+
class BaseDeserializerTest:
    """Checks that incomplete ``BaseDeserializer`` subclasses cannot be instantiated."""

    # NOTE(review): pytest's default class-name pattern is ``Test*``; confirm
    # the project's ``python_classes`` setting actually collects this class,
    # otherwise these tests never run.

    def test_instantiate_deserializer_without_accept_attribute(self):
        # Overrides ``deserialize`` but leaves the abstract ``ACCEPT`` unset.
        class StubDeserializer(deserializers.BaseDeserializer):
            def deserialize(self, data, content_type):
                pass

        with pytest.raises(TypeError):
            StubDeserializer()

    def test_instantiate_deserializer_without_deserialize_method(self):
        # Provides ``ACCEPT`` but leaves the abstract ``deserialize`` unset.
        class StubDeserializer(deserializers.BaseDeserializer):
            ACCEPT = "application/json"

        with pytest.raises(TypeError):
            StubDeserializer()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker import serializers
18+
19+
20+
class BaseSerializerTest:
    """Checks that incomplete ``BaseSerializer`` subclasses cannot be instantiated."""

    # NOTE(review): pytest's default class-name pattern is ``Test*``; confirm
    # the project's ``python_classes`` setting actually collects this class,
    # otherwise these tests never run.

    def test_instantiate_serializer_without_content_type_attribute(self):
        # Overrides ``serialize`` but leaves the abstract ``CONTENT_TYPE`` unset.
        class StubSerializer(serializers.BaseSerializer):
            def serialize(self, data):
                pass

        with pytest.raises(TypeError):
            StubSerializer()

    def test_instantiate_serializer_without_serialize_method(self):
        # Provides ``CONTENT_TYPE`` but leaves the abstract ``serialize`` unset.
        class StubSerializer(serializers.BaseSerializer):
            CONTENT_TYPE = "application/json"

        with pytest.raises(TypeError):
            StubSerializer()

0 commit comments

Comments
 (0)