-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: add BaseSerializer and BaseDeserializer #1668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3bc8575
2d69f65
0717895
a5fd621
bd909bb
cb249f8
8072555
d6fa55d
148c1e9
539978d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Implements methods for deserializing data returned from an inference endpoint.""" | ||
from __future__ import absolute_import | ||
|
||
import abc | ||
|
||
|
||
class BaseDeserializer(abc.ABC): | ||
"""Abstract base class for creation of new deserializers. | ||
|
||
Provides a skeleton for customization requiring the overriding of the method | ||
deserialize and the class attribute ACCEPT. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def deserialize(self, data, content_type): | ||
"""Deserialize data received from an inference endpoint. | ||
|
||
Args: | ||
data (object): Data to be deserialized. | ||
content_type (str): The MIME type of the data. | ||
|
||
Returns: | ||
object: The data deserialized into an object. | ||
""" | ||
|
||
@property | ||
@abc.abstractmethod | ||
def ACCEPT(self): | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""The content type that is expected from the inference endpoint.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. have you given any more thought on how to support serdes (serializers/deserializers) that can and should be able to accept and operate on a collection of content types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, the user will be able to specify a string, list, or dictionary of accept types. Because adding support for multiple accept types necessitates changes to the sagemaker-inference-toolkit, I'm planning on first migrating the existing serializers/deserializers (which have only one accept type) and then adding support for multiple accept types. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,9 @@ | |
import numpy as np | ||
|
||
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY | ||
from sagemaker.deserializers import BaseDeserializer | ||
from sagemaker.model_monitor import DataCaptureConfig | ||
from sagemaker.serializers import BaseSerializer | ||
from sagemaker.session import production_variant, Session | ||
from sagemaker.utils import name_from_base | ||
|
||
|
@@ -59,27 +61,28 @@ def __init__( | |
object, used for SageMaker interactions (default: None). If not | ||
specified, one is created using the default AWS configuration | ||
chain. | ||
serializer (callable): Accepts a single argument, the input data, | ||
and returns a sequence of bytes. It may provide a | ||
``content_type`` attribute that defines the endpoint request | ||
content type. If not specified, a sequence of bytes is expected | ||
for the data. | ||
deserializer (callable): Accepts two arguments, the result data and | ||
the response content type, and returns a sequence of bytes. It | ||
may provide a ``content_type`` attribute that defines the | ||
endpoint response's "Accept" content type. If not specified, a | ||
sequence of bytes is expected for the data. | ||
serializer (sagemaker.serializers.BaseSerializer): A serializer | ||
object, used to encode data for an inference endpoint | ||
(default: None). | ||
deserializer (sagemaker.deserializers.BaseDeserializer): A | ||
deserializer object, used to decode data from an inference | ||
endpoint (default: None). | ||
content_type (str): The invocation's "ContentType", overriding any | ||
``content_type`` from the serializer (default: None). | ||
``CONTENT_TYPE`` from the serializer (default: None). | ||
accept (str): The invocation's "Accept", overriding any accept from | ||
the deserializer (default: None). | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
if serializer is not None and not isinstance(serializer, BaseSerializer): | ||
serializer = LegacySerializer(serializer) | ||
if deserializer is not None and not isinstance(deserializer, BaseDeserializer): | ||
deserializer = LegacyDeserializer(deserializer) | ||
|
||
self.endpoint_name = endpoint_name | ||
self.sagemaker_session = sagemaker_session or Session() | ||
self.serializer = serializer | ||
self.deserializer = deserializer | ||
self.content_type = content_type or getattr(serializer, "content_type", None) | ||
self.accept = accept or getattr(deserializer, "accept", None) | ||
self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None) | ||
self.accept = accept or getattr(deserializer, "ACCEPT", None) | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._endpoint_config_name = self._get_endpoint_config_name() | ||
self._model_names = self._get_model_names() | ||
|
||
|
@@ -120,8 +123,10 @@ def _handle_response(self, response): | |
""" | ||
response_body = response["Body"] | ||
if self.deserializer is not None: | ||
if not isinstance(self.deserializer, BaseDeserializer): | ||
self.deserializer = LegacyDeserializer(self.deserializer) | ||
# It's the deserializer's responsibility to close the stream | ||
return self.deserializer(response_body, response["ContentType"]) | ||
return self.deserializer.deserialize(response_body, response["ContentType"]) | ||
data = response_body.read() | ||
response_body.close() | ||
return data | ||
|
@@ -152,7 +157,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe | |
args["TargetVariant"] = target_variant | ||
|
||
if self.serializer is not None: | ||
data = self.serializer(data) | ||
if not isinstance(self.serializer, BaseSerializer): | ||
self.serializer = LegacySerializer(self.serializer) | ||
data = self.serializer.serialize(data) | ||
|
||
args["Body"] = data | ||
return args | ||
|
@@ -406,6 +413,88 @@ def _get_model_names(self): | |
return [d["ModelName"] for d in production_variants] | ||
|
||
|
||
class LegacySerializer(BaseSerializer): | ||
"""Wrapper that makes legacy serializers forward compatibile.""" | ||
|
||
def __init__(self, serializer): | ||
"""Initialize a ``LegacySerializer``. | ||
|
||
Args: | ||
serializer (callable): A legacy serializer. | ||
""" | ||
self.serializer = serializer | ||
self.content_type = getattr(serializer, "content_type", None) | ||
|
||
def __call__(self, *args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are the function signatures that different across legacy implementations? I think it would be better to avoid There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some of the classes have different function signatures. For example, |
||
"""Wraps the call method of the legacy serializer. | ||
|
||
Args: | ||
data (object): Data to be serialized. | ||
|
||
Returns: | ||
object: Serialized data used for a request. | ||
""" | ||
return self.serializer(*args, **kwargs) | ||
|
||
def serialize(self, data): | ||
"""Wraps the call method of the legacy serializer. | ||
|
||
Args: | ||
data (object): Data to be serialized. | ||
|
||
Returns: | ||
object: Serialized data used for a request. | ||
""" | ||
return self.serializer(data) | ||
|
||
@property | ||
def CONTENT_TYPE(self): | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""The MIME type of the data sent to the inference endpoint.""" | ||
return self.content_type | ||
|
||
|
||
class LegacyDeserializer(BaseDeserializer): | ||
bveeramani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Wrapper that makes legacy deserializers forward compatibile.""" | ||
|
||
def __init__(self, deserializer): | ||
"""Initialize a ``LegacyDeserializer``. | ||
|
||
Args: | ||
deserializer (callable): A legacy deserializer. | ||
""" | ||
self.deserializer = deserializer | ||
self.accept = getattr(deserializer, "accept", None) | ||
|
||
def __call__(self, *args, **kwargs): | ||
"""Wraps the call method of the legacy deserializer. | ||
|
||
Args: | ||
data (object): Data to be deserialized. | ||
content_type (str): The MIME type of the data. | ||
|
||
Returns: | ||
object: The data deserialized into an object. | ||
""" | ||
return self.deserializer(*args, **kwargs) | ||
|
||
def deserialize(self, data, content_type): | ||
"""Wraps the call method of the legacy deserializer. | ||
|
||
Args: | ||
data (object): Data to be deserialized. | ||
content_type (str): The MIME type of the data. | ||
|
||
Returns: | ||
object: The data deserialized into an object. | ||
""" | ||
return self.deserializer(data, content_type) | ||
|
||
@property | ||
def ACCEPT(self): | ||
"""The content type that is expected from the inference endpoint.""" | ||
return self.accept | ||
|
||
|
||
class _CsvSerializer(object): | ||
"""Placeholder docstring""" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Implements methods for serializing data for an inference endpoint.""" | ||
from __future__ import absolute_import | ||
|
||
import abc | ||
|
||
|
||
class BaseSerializer(abc.ABC): | ||
"""Abstract base class for creation of new serializers. | ||
|
||
Provides a skeleton for customization requiring the overriding of the method | ||
serialize and the class attribute CONTENT_TYPE. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def serialize(self, data): | ||
"""Serialize data into the media type specified by CONTENT_TYPE. | ||
|
||
Args: | ||
data (object): Data to be serialized. | ||
|
||
Returns: | ||
object: Serialized data used for a request. | ||
""" | ||
|
||
@property | ||
@abc.abstractmethod | ||
def CONTENT_TYPE(self): | ||
"""The MIME type of the data sent to the inference endpoint.""" |
Uh oh!
There was an error while loading. Please reload this page.