-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: add BaseSerializer and BaseDeserializer #1668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
3bc8575
2d69f65
0717895
a5fd621
bd909bb
cb249f8
8072555
d6fa55d
148c1e9
539978d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Implements methods for deserializing data returned from an inference endpoint.""" | ||
from __future__ import absolute_import | ||
|
||
import abc | ||
|
||
|
||
class BaseDeserializer(abc.ABC):
    """Interface that every deserializer has to implement.

    Concrete subclasses supply two things: a ``deserialize`` method and an
    ``ACCEPT`` attribute. A subclass that leaves either one out remains
    abstract and cannot be instantiated.
    """

    # NOTE(review): the pull-request discussion mentions follow-up work that
    # would let ACCEPT be a string, list, or dictionary of types; only a single
    # content type is documented here.
    @property
    @abc.abstractmethod
    def ACCEPT(self):
        """The content type that is expected from the inference endpoint."""
        raise NotImplementedError()

    @abc.abstractmethod
    def deserialize(self, data, content_type):
        """Turn the payload returned by an inference endpoint into an object.

        Args:
            data (object): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            object: The data deserialized into an object.
        """
        raise NotImplementedError()
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,9 @@ | |
import numpy as np | ||
|
||
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY | ||
from sagemaker.deserializers import BaseDeserializer | ||
from sagemaker.model_monitor import DataCaptureConfig | ||
from sagemaker.serializers import BaseSerializer | ||
from sagemaker.session import production_variant, Session | ||
from sagemaker.utils import name_from_base | ||
|
||
|
@@ -59,27 +61,23 @@ def __init__( | |
object, used for SageMaker interactions (default: None). If not | ||
specified, one is created using the default AWS configuration | ||
chain. | ||
serializer (callable): Accepts a single argument, the input data, | ||
and returns a sequence of bytes. It may provide a | ||
``content_type`` attribute that defines the endpoint request | ||
content type. If not specified, a sequence of bytes is expected | ||
for the data. | ||
deserializer (callable): Accepts two arguments, the result data and | ||
the response content type, and returns a sequence of bytes. It | ||
may provide a ``content_type`` attribute that defines the | ||
endpoint response's "Accept" content type. If not specified, a | ||
sequence of bytes is expected for the data. | ||
serializer (sagemaker.serializers.BaseSerializer): A serializer | ||
object, used to encode data for an inference endpoint | ||
(default: None). | ||
deserializer (sagemaker.deserializers.BaseDeserializer): A | ||
deserializer object, used to decode data from an inference | ||
endpoint (default: None). | ||
content_type (str): The invocation's "ContentType", overriding any | ||
``content_type`` from the serializer (default: None). | ||
``CONTENT_TYPE`` from the serializer (default: None). | ||
accept (str): The invocation's "Accept", overriding any accept from | ||
the deserializer (default: None). | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
self.endpoint_name = endpoint_name | ||
self.sagemaker_session = sagemaker_session or Session() | ||
self.serializer = serializer | ||
self.deserializer = deserializer | ||
self.content_type = content_type or getattr(serializer, "content_type", None) | ||
self.accept = accept or getattr(deserializer, "accept", None) | ||
self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None) | ||
self.accept = accept or getattr(deserializer, "ACCEPT", None) | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._endpoint_config_name = self._get_endpoint_config_name() | ||
self._model_names = self._get_model_names() | ||
|
||
|
@@ -121,7 +119,7 @@ def _handle_response(self, response): | |
response_body = response["Body"] | ||
if self.deserializer is not None: | ||
# It's the deserializer's responsibility to close the stream | ||
return self.deserializer(response_body, response["ContentType"]) | ||
return self.deserializer.deserialize(response_body, response["ContentType"]) | ||
data = response_body.read() | ||
response_body.close() | ||
return data | ||
|
@@ -152,7 +150,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe | |
args["TargetVariant"] = target_variant | ||
|
||
if self.serializer is not None: | ||
data = self.serializer(data) | ||
data = self.serializer.serialize(data) | ||
|
||
args["Body"] = data | ||
return args | ||
|
@@ -406,6 +404,88 @@ def _get_model_names(self): | |
return [d["ModelName"] for d in production_variants] | ||
|
||
|
||
class LegacySerializer(BaseSerializer):
    """Wrapper that makes legacy serializers forward compatible."""

    def __init__(self, serializer):
        """Wrap a legacy, callable-style serializer.

        Args:
            serializer (callable): A legacy serializer. It is called with the
                data to serialize and may expose a ``content_type`` attribute.
        """
        self.serializer = serializer
        # Legacy serializers are only required to be callable; fall back to
        # None instead of raising AttributeError when no content type is
        # declared on the wrapped object.
        self.content_type = getattr(serializer, "content_type", None)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped legacy serializer directly.

        Arguments are forwarded untouched because, per the review discussion,
        the legacy classes do not all share one call signature.

        Args:
            *args: Positional arguments passed to the legacy serializer.
            **kwargs: Keyword arguments passed to the legacy serializer.

        Returns:
            object: Serialized data used for a request.
        """
        return self.serializer(*args, **kwargs)

    def serialize(self, data):
        """Serialize data by calling the wrapped legacy serializer.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """
        return self.serializer(data)

    @property
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.content_type
|
||
|
||
class LegacyDeserializer(BaseDeserializer):
    """Wrapper that makes legacy deserializers forward compatible."""

    def __init__(self, deserializer):
        """Wrap a legacy, callable-style deserializer.

        Args:
            deserializer (callable): A legacy deserializer. It is called with
                the response data and its content type, and may expose an
                ``accept`` attribute.
        """
        self.deserializer = deserializer
        # Legacy deserializers are only required to be callable; fall back to
        # None instead of raising AttributeError when no accept type is
        # declared on the wrapped object.
        self.accept = getattr(deserializer, "accept", None)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped legacy deserializer directly.

        Args:
            *args: Positional arguments passed to the legacy deserializer
                (the data to be deserialized and its MIME type).
            **kwargs: Keyword arguments passed to the legacy deserializer.

        Returns:
            object: The data deserialized into an object.
        """
        return self.deserializer(*args, **kwargs)

    def deserialize(self, data, content_type):
        """Deserialize data by calling the wrapped legacy deserializer.

        Args:
            data (object): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            object: The data deserialized into an object.
        """
        return self.deserializer(data, content_type)

    @property
    def ACCEPT(self):
        """The content type that is expected from the inference endpoint."""
        return self.accept
|
||
|
||
class _CsvSerializer(object): | ||
"""Placeholder docstring""" | ||
|
||
|
@@ -479,7 +559,7 @@ def _csv_serialize_object(data): | |
return csv_buffer.getvalue().rstrip("\r\n") | ||
|
||
|
||
csv_serializer = _CsvSerializer() | ||
csv_serializer = LegacySerializer(_CsvSerializer()) | ||
|
||
|
||
def _is_mutable_sequence_like(obj): | ||
|
@@ -531,7 +611,7 @@ def __call__(self, stream, content_type): | |
stream.close() | ||
|
||
|
||
csv_deserializer = _CsvDeserializer() | ||
csv_deserializer = LegacyDeserializer(_CsvDeserializer()) | ||
|
||
|
||
class BytesDeserializer(object): | ||
|
@@ -644,7 +724,7 @@ def __call__(self, data): | |
return json.dumps(_ndarray_to_list(data)) | ||
|
||
|
||
json_serializer = _JsonSerializer() | ||
json_serializer = LegacySerializer(_JsonSerializer()) | ||
|
||
|
||
def _ndarray_to_list(data): | ||
|
@@ -686,7 +766,7 @@ def __call__(self, stream, content_type): | |
stream.close() | ||
|
||
|
||
json_deserializer = _JsonDeserializer() | ||
json_deserializer = LegacyDeserializer(_JsonDeserializer()) | ||
|
||
|
||
class _NumpyDeserializer(object): | ||
|
@@ -730,7 +810,7 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY): | |
) | ||
|
||
|
||
numpy_deserializer = _NumpyDeserializer() | ||
numpy_deserializer = LegacyDeserializer(_NumpyDeserializer()) | ||
|
||
|
||
class _NPYSerializer(object): | ||
|
@@ -778,4 +858,4 @@ def _npy_serialize(data): | |
return buffer.getvalue() | ||
|
||
|
||
npy_serializer = _NPYSerializer() | ||
npy_serializer = LegacySerializer(_NPYSerializer()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Implements methods for serializing data for an inference endpoint.""" | ||
from __future__ import absolute_import | ||
|
||
import abc | ||
|
||
|
||
class BaseSerializer(abc.ABC):
    """Interface that every serializer has to implement.

    Concrete subclasses supply two things: a ``serialize`` method and a
    ``CONTENT_TYPE`` attribute. A subclass that leaves either one out remains
    abstract and cannot be instantiated.
    """

    @property
    @abc.abstractmethod
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""
        raise NotImplementedError()

    @abc.abstractmethod
    def serialize(self, data):
        """Encode data as the media type named by ``CONTENT_TYPE``.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """
        raise NotImplementedError()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import pytest | ||
|
||
from sagemaker import deserializers | ||
|
||
|
||
class TestBaseDeserializer:
    """Checks that ``BaseDeserializer`` enforces its abstract interface.

    The class name starts with ``Test`` so that pytest's default discovery
    rules collect it; the previous ``*Test`` suffix form is not matched by
    the default ``python_classes`` pattern.
    """

    def test_instantiate_deserializer_without_accept_attribute(self):
        """A subclass defining only ``deserialize`` cannot be instantiated."""

        class StubDeserializer(deserializers.BaseDeserializer):
            def deserialize(self, data, content_type):
                pass

        with pytest.raises(TypeError):
            StubDeserializer()

    def test_instantiate_deserializer_without_deserialize_method(self):
        """A subclass defining only ``ACCEPT`` cannot be instantiated."""

        class StubDeserializer(deserializers.BaseDeserializer):
            ACCEPT = "application/json"

        with pytest.raises(TypeError):
            StubDeserializer()


# Kept so any reference to the previous class name still resolves. The alias
# does not match pytest's default ``Test*`` pattern, so nothing runs twice.
BaseDeserializerTest = TestBaseDeserializer
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import pytest | ||
|
||
from sagemaker import serializers | ||
|
||
|
||
class TestBaseSerializer:
    """Checks that ``BaseSerializer`` enforces its abstract interface.

    The class name starts with ``Test`` so that pytest's default discovery
    rules collect it; the previous ``*Test`` suffix form is not matched by
    the default ``python_classes`` pattern.
    """

    def test_instantiate_serializer_without_content_type_attribute(self):
        """A subclass defining only ``serialize`` cannot be instantiated."""

        class StubSerializer(serializers.BaseSerializer):
            def serialize(self, data):
                pass

        with pytest.raises(TypeError):
            StubSerializer()

    def test_instantiate_serializer_without_serialize_method(self):
        """A subclass defining only ``CONTENT_TYPE`` cannot be instantiated."""

        class StubSerializer(serializers.BaseSerializer):
            CONTENT_TYPE = "application/json"

        with pytest.raises(TypeError):
            StubSerializer()


# Kept so any reference to the previous class name still resolves. The alias
# does not match pytest's default ``Test*`` pattern, so nothing runs twice.
BaseSerializerTest = TestBaseSerializer
Uh oh!
There was an error while loading. Please reload this page.