feature: Remove LegacySerializer and LegacyDeserializer #1741

Merged: 7 commits, Jul 24, 2020
93 changes: 0 additions & 93 deletions src/sagemaker/predictor.py
@@ -13,9 +13,7 @@
"""Placeholder docstring"""
from __future__ import print_function, absolute_import

from sagemaker.deserializers import BaseDeserializer
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.serializers import BaseSerializer
from sagemaker.session import production_variant, Session
from sagemaker.utils import name_from_base

@@ -64,11 +62,6 @@ def __init__(
accept (str): The invocation's "Accept", overriding any accept from
the deserializer (default: None).
"""
if serializer is not None and not isinstance(serializer, BaseSerializer):
serializer = LegacySerializer(serializer)
if deserializer is not None and not isinstance(deserializer, BaseDeserializer):
deserializer = LegacyDeserializer(deserializer)

self.endpoint_name = endpoint_name
self.sagemaker_session = sagemaker_session or Session()
self.serializer = serializer
@@ -115,8 +108,6 @@ def _handle_response(self, response):
"""
response_body = response["Body"]
if self.deserializer is not None:
if not isinstance(self.deserializer, BaseDeserializer):
self.deserializer = LegacyDeserializer(self.deserializer)
# It's the deserializer's responsibility to close the stream
return self.deserializer.deserialize(response_body, response["ContentType"])
data = response_body.read()
@@ -149,8 +140,6 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
args["TargetVariant"] = target_variant

if self.serializer is not None:
if not isinstance(self.serializer, BaseSerializer):
self.serializer = LegacySerializer(self.serializer)
data = self.serializer.serialize(data)

args["Body"] = data
@@ -403,85 +392,3 @@ def _get_model_names(self):
)
production_variants = endpoint_config["ProductionVariants"]
return [d["ModelName"] for d in production_variants]


class LegacySerializer(BaseSerializer):
"""Wrapper that makes legacy serializers forward compatibile."""

def __init__(self, serializer):
"""Initialize a ``LegacySerializer``.

Args:
serializer (callable): A legacy serializer.
"""
self.serializer = serializer
self.content_type = getattr(serializer, "content_type", None)

def __call__(self, *args, **kwargs):
"""Wraps the call method of the legacy serializer.

Args:
data (object): Data to be serialized.

Returns:
object: Serialized data used for a request.
"""
return self.serializer(*args, **kwargs)

def serialize(self, data):
"""Wraps the call method of the legacy serializer.

Args:
data (object): Data to be serialized.

Returns:
object: Serialized data used for a request.
"""
return self.serializer(data)

@property
def CONTENT_TYPE(self):
"""The MIME type of the data sent to the inference endpoint."""
return self.content_type


class LegacyDeserializer(BaseDeserializer):
"""Wrapper that makes legacy deserializers forward compatibile."""

def __init__(self, deserializer):
"""Initialize a ``LegacyDeserializer``.

Args:
deserializer (callable): A legacy deserializer.
"""
self.deserializer = deserializer
self.accept = getattr(deserializer, "accept", None)

def __call__(self, *args, **kwargs):
"""Wraps the call method of the legacy deserializer.

Args:
data (object): Data to be deserialized.
content_type (str): The MIME type of the data.

Returns:
object: The data deserialized into an object.
"""
return self.deserializer(*args, **kwargs)

def deserialize(self, data, content_type):
"""Wraps the call method of the legacy deserializer.

Args:
data (object): Data to be deserialized.
content_type (str): The MIME type of the data.

Returns:
object: The data deserialized into an object.
"""
return self.deserializer(data, content_type)

@property
def ACCEPT(self):
"""The content type that is expected from the inference endpoint."""
return self.accept
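
Migration note (not part of the diff): with LegacySerializer removed, a plain callable assigned to predictor.serializer is no longer wrapped automatically, so a custom serializer should implement the BaseSerializer interface itself, as the updated integration tests below do. A minimal sketch under that assumption; the CSVRowSerializer name and the CSV format are illustrative only, not code from this PR:

from sagemaker.serializers import BaseSerializer


class CSVRowSerializer(BaseSerializer):
    """Hypothetical serializer: turns an iterable of rows into CSV text."""

    # Satisfies the CONTENT_TYPE property expected by BaseSerializer; this is
    # the MIME type of the data sent to the inference endpoint.
    CONTENT_TYPE = "text/csv"

    def serialize(self, data):
        # One line per row, values joined with commas.
        return "\n".join(",".join(str(value) for value in row) for row in data)


# Usage: assign an instance (not a bare function) to the predictor.
# predictor.serializer = CSVRowSerializer()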
21 changes: 12 additions & 9 deletions tests/integ/test_byo_estimator.py
@@ -20,6 +20,7 @@
import sagemaker
from sagemaker import image_uris
from sagemaker.estimator import Estimator
from sagemaker.serializers import BaseSerializer
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -35,11 +36,15 @@ def training_set():
return datasets.one_p_mnist()


def fm_serializer(data):
js = {"instances": []}
for row in data:
js["instances"].append({"features": row.tolist()})
return json.dumps(js)
class _FactorizationMachineSerializer(BaseSerializer):

CONTENT_TYPE = "application/json"

def serialize(self, data):
js = {"instances": []}
for row in data:
js["instances"].append({"features": row.tolist()})
return json.dumps(js)


@pytest.mark.canary_quick
@@ -84,8 +89,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
model = estimator.create_model()
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
predictor.serializer = fm_serializer
predictor.content_type = "application/json"
predictor.serializer = _FactorizationMachineSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

result = predictor.predict(training_set[0][:10])
@@ -130,8 +134,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
)
model = estimator.create_model()
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor.serializer = fm_serializer
predictor.content_type = "application/json"
predictor.serializer = _FactorizationMachineSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

result = predictor.predict(training_set[0][:10])
18 changes: 11 additions & 7 deletions tests/integ/test_tuner.py
@@ -28,6 +28,7 @@
from sagemaker.estimator import Estimator
from sagemaker.mxnet.estimator import MXNet
from sagemaker.pytorch import PyTorch
from sagemaker.serializers import BaseSerializer
from sagemaker.tensorflow import TensorFlow
from sagemaker.tuner import (
IntegerParameter,
@@ -900,8 +901,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
best_training_job = tuner.best_training_job()
with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
predictor.serializer = _fm_serializer
predictor.content_type = "application/json"
predictor.serializer = _FactorizationMachineSerializer()
predictor.deserializer = JSONDeserializer()

result = predictor.predict(datasets.one_p_mnist()[0][:10])
@@ -912,11 +912,15 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):


# Serializer for the Factorization Machines predictor (for BYO example)
def _fm_serializer(data):
js = {"instances": []}
for row in data:
js["instances"].append({"features": row.tolist()})
return json.dumps(js)
class _FactorizationMachineSerializer(BaseSerializer):

CONTENT_TYPE = "application/json"

def serialize(self, data):
js = {"instances": []}
for row in data:
js["instances"].append({"features": row.tolist()})
return json.dumps(js)


def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
18 changes: 11 additions & 7 deletions tests/integ/test_tuner_multi_algo.py
@@ -21,6 +21,7 @@
from sagemaker.analytics import HyperparameterTuningJobAnalytics
from sagemaker.deserializers import JSONDeserializer
from sagemaker.estimator import Estimator
from sagemaker.serializers import BaseSerializer
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner
from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -213,14 +214,17 @@ def _create_training_inputs(sagemaker_session):


def _make_prediction(predictor, data):
predictor.serializer = _prediction_data_serializer
predictor.content_type = "application/json"
predictor.serializer = PredictionDataSerializer()
predictor.deserializer = JSONDeserializer()
return predictor.predict(data)


def _prediction_data_serializer(data):
js = {"instances": []}
for row in data:
js["instances"].append({"features": row.tolist()})
return json.dumps(js)
class PredictionDataSerializer(BaseSerializer):

CONTENT_TYPE = "application/json"

def serialize(self, data):
js = {"instances": []}
for row in data:
js["instances"].append({"features": row.tolist()})
return json.dumps(js)
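
For completeness, a custom deserializer follows the same pattern now that LegacyDeserializer is gone: subclass BaseDeserializer, declare ACCEPT, and implement deserialize(stream, content_type). The JSONLinesDeserializer below is a hypothetical sketch, not code from this PR; as noted in _handle_response above, closing the stream is the deserializer's responsibility.

import json

from sagemaker.deserializers import BaseDeserializer


class JSONLinesDeserializer(BaseDeserializer):
    """Hypothetical deserializer: parses each line of the response body as JSON."""

    # The content type expected from the inference endpoint.
    ACCEPT = "application/jsonlines"

    def deserialize(self, stream, content_type):
        try:
            # stream is the raw response body; decode it and parse line by line.
            lines = stream.read().decode("utf-8").splitlines()
            return [json.loads(line) for line in lines if line.strip()]
        finally:
            # It's the deserializer's responsibility to close the stream.
            stream.close()


# Usage:
# predictor.deserializer = JSONLinesDeserializer()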