
breaking: Remove content_type and accept parameters from Predictor #1751


Merged: 26 commits merged into zwei from remove-serde-parameters on Jul 30, 2020.

Commits
11616f2
Remove content_type and accept parameters
Jul 24, 2020
b6643f4
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 24, 2020
5b43cd0
Update v2.rst
Jul 24, 2020
ce69c3a
Merge branch 'remove-serde-parameters' of github.com:bveeramani/sagem…
Jul 24, 2020
a805378
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 24, 2020
7a79f55
Add IdentitySerializer unit test
Jul 24, 2020
3fb26c6
Merge branch 'zwei' into remove-serde-parameters
Jul 28, 2020
dcfc735
Fix tests
Jul 28, 2020
c46b00c
Clean code
Jul 28, 2020
5fb78cc
Improve code
Jul 28, 2020
11dcdd2
Format code
Jul 28, 2020
9b2aace
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 29, 2020
1fc531a
Fix tests
Jul 29, 2020
f4eacaf
Address review comments
Jul 29, 2020
504a49f
Fix docs
Jul 29, 2020
7c0cf42
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 29, 2020
6d84df9
Update test_multi_variant_endpoint.py
Jul 29, 2020
d91ce31
Fix integration tests
Jul 29, 2020
1f2de10
Update test_tfs.py
Jul 29, 2020
4eec9a7
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 29, 2020
96381cf
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 29, 2020
d7b0a5d
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 30, 2020
46cbd65
Update test_tfs.py
Jul 30, 2020
dea357b
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 30, 2020
7e026cb
Merge branch 'zwei' into remove-serde-parameters
Jul 30, 2020
d16319e
Merge branch 'zwei' into remove-serde-parameters
bveeramani Jul 30, 2020
13 changes: 12 additions & 1 deletion doc/v2.rst
Original file line number Diff line number Diff line change
@@ -94,6 +94,18 @@ Please instantiate the objects instead.
The ``update_endpoint`` argument in ``deploy()`` methods for estimators and models has been deprecated.
Please use :func:`sagemaker.predictor.Predictor.update_endpoint` instead.

``content_type`` and ``accept`` in the Predictor Constructor
------------------------------------------------------------

The ``content_type`` and ``accept`` parameters have been removed from the
following classes and methods:
- ``sagemaker.predictor.Predictor``
- ``sagemaker.estimator.Estimator.create_model``
- ``sagemaker.algorithms.AlgorithmEstimator.create_model``
- ``sagemaker.tensorflow.model.TensorFlowPredictor``

Please specify content types in a serializer or deserializer class instead.

``sagemaker.content_types``
---------------------------

@@ -115,7 +127,6 @@ write MIME types as a string,
| ``CONTENT_TYPE_NPY`` | ``"application/x-npy"`` |
+-------------------------------+--------------------------------+


Require ``framework_version`` and ``py_version`` for Frameworks
===============================================================

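The migration described in the v2.rst note above ("specify content types in a serializer or deserializer class instead") can be sketched with a minimal, self-contained example. The classes below are simplified stand-ins for the real SDK classes, just to show where the content type now lives:

```python
# Minimal sketch (not the real SDK classes): the ContentType/Accept headers
# now come from the serializer/deserializer objects, instead of being passed
# as separate content_type/accept arguments to Predictor.

class CSVSerializer:
    CONTENT_TYPE = "text/csv"

    def serialize(self, data):
        # join a flat sequence of values into one CSV record
        return ",".join(str(v) for v in data)

class JSONDeserializer:
    ACCEPT = "application/json"

class Predictor:
    def __init__(self, endpoint_name, serializer, deserializer):
        self.endpoint_name = endpoint_name
        self.serializer = serializer
        self.deserializer = deserializer
        # derived from the serde objects; no content_type/accept parameters
        self.content_type = serializer.CONTENT_TYPE
        self.accept = deserializer.ACCEPT

predictor = Predictor("my-endpoint", CSVSerializer(), JSONDeserializer())
print(predictor.content_type)  # text/csv
print(predictor.accept)        # application/json
```

The endpoint name is hypothetical; the point is that callers who previously passed `content_type="text/csv"` now get the same header by choosing a serializer.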
34 changes: 14 additions & 20 deletions src/sagemaker/algorithm.py
@@ -16,7 +16,9 @@
import sagemaker
import sagemaker.parameter
from sagemaker import vpc_utils
from sagemaker.deserializers import BytesDeserializer
from sagemaker.estimator import EstimatorBase
from sagemaker.serializers import IdentitySerializer
from sagemaker.transformer import Transformer
from sagemaker.predictor import Predictor

@@ -251,37 +253,29 @@ def create_model(
self,
role=None,
predictor_cls=None,
serializer=None,
deserializer=None,
content_type=None,
accept=None,
serializer=IdentitySerializer(),
deserializer=BytesDeserializer(),
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
**kwargs
):
"""Create a model to deploy.

The serializer, deserializer, content_type, and accept arguments are
only used to define a default Predictor. They are ignored if an
explicit predictor class is passed in. Other arguments are passed
through to the Model class.
The serializer and deserializer are only used to define a default
Predictor. They are ignored if an explicit predictor class is passed in.
Other arguments are passed through to the Model class.

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
which is also used during transform jobs. If not specified, the
role from the Estimator will be used.
predictor_cls (Predictor): The predictor class to use when
deploying the model.
serializer (callable): Should accept a single argument, the input
data, and return a sequence of bytes. May provide a content_type
attribute that defines the endpoint request content type
deserializer (callable): Should accept two arguments, the result
data and the response content type, and return a sequence of
bytes. May provide a content_type attribute that defines the
endpoint response Accept content type.
content_type (str): The invocation ContentType, overriding any
content_type from the serializer
accept (str): The invocation Accept, overriding any accept from the
deserializer.
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
serializer object, used to encode data for an inference endpoint
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
deserializer object, used to decode data from an inference
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
@@ -300,7 +294,7 @@ def create_model(
if predictor_cls is None:

def predict_wrapper(endpoint, session):
return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
return Predictor(endpoint, session, serializer, deserializer)

predictor_cls = predict_wrapper

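The `predict_wrapper` pattern in the algorithm.py diff above can be sketched in isolation. `FakePredictor` and `make_predictor_cls` are illustrative names, not part of the SDK; the sketch shows how, when no `predictor_cls` is supplied, a factory function closes over the serializer and deserializer so a default Predictor can be built later:

```python
# Sketch of the default-predictor factory used by create_model: a wrapper
# function captures the serde objects and stands in for the predictor class.

class FakePredictor:
    def __init__(self, endpoint, session, serializer, deserializer):
        self.endpoint = endpoint
        self.session = session
        self.serializer = serializer
        self.deserializer = deserializer

def make_predictor_cls(serializer, deserializer, predictor_cls=None):
    if predictor_cls is None:
        def predict_wrapper(endpoint, session):
            # the captured serde objects flow into the default predictor
            return FakePredictor(endpoint, session, serializer, deserializer)
        predictor_cls = predict_wrapper
    return predictor_cls

cls = make_predictor_cls("csv-serializer", "json-deserializer")
predictor = cls("my-endpoint", None)
print(predictor.serializer)  # csv-serializer
```

An explicitly passed `predictor_cls` is returned unchanged, which is why the docstring says the serializer and deserializer are ignored in that case.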
29 changes: 12 additions & 17 deletions src/sagemaker/estimator.py
@@ -29,7 +29,9 @@
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
from sagemaker.debugger import get_rule_container_image_uri
from sagemaker.deserializers import BytesDeserializer
from sagemaker.s3 import S3Uploader, parse_s3_url
from sagemaker.serializers import IdentitySerializer

from sagemaker.fw_utils import (
tar_and_upload_dir,
@@ -1340,16 +1342,14 @@ def create_model(
role=None,
image_uri=None,
predictor_cls=None,
serializer=None,
deserializer=None,
content_type=None,
accept=None,
serializer=IdentitySerializer(),
deserializer=BytesDeserializer(),
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
**kwargs
):
"""Create a model to deploy.

The serializer, deserializer, content_type, and accept arguments are only used to define a
The serializer and deserializer arguments are only used to define a
default Predictor. They are ignored if an explicit predictor class is passed in.
Other arguments are passed through to the Model class.

@@ -1361,17 +1361,12 @@
Defaults to the image used for training.
predictor_cls (Predictor): The predictor class to use when
deploying the model.
serializer (callable): Should accept a single argument, the input
data, and return a sequence of bytes. May provide a content_type
attribute that defines the endpoint request content type
deserializer (callable): Should accept two arguments, the result
data and the response content type, and return a sequence of
bytes. May provide a content_type attribute that defines th
endpoint response Accept content type.
content_type (str): The invocation ContentType, overriding any
content_type from the serializer
accept (str): The invocation Accept, overriding any accept from the
deserializer.
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
serializer object, used to encode data for an inference endpoint
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
deserializer object, used to decode data from an inference
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
the model.
Default: use subnets and security groups from this Estimator.
@@ -1390,7 +1385,7 @@ def create_model(
if predictor_cls is None:

def predict_wrapper(endpoint, session):
return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
return Predictor(endpoint, session, serializer, deserializer)

predictor_cls = predict_wrapper

41 changes: 16 additions & 25 deletions src/sagemaker/predictor.py
@@ -13,7 +13,9 @@
"""Placeholder docstring"""
from __future__ import print_function, absolute_import

from sagemaker.deserializers import BytesDeserializer
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.serializers import IdentitySerializer
from sagemaker.session import production_variant, Session
from sagemaker.utils import name_from_base

@@ -31,10 +33,8 @@ def __init__(
self,
endpoint_name,
sagemaker_session=None,
serializer=None,
deserializer=None,
content_type=None,
accept=None,
serializer=IdentitySerializer(),
deserializer=BytesDeserializer(),
):
"""Initialize a ``Predictor``.

@@ -51,23 +51,19 @@
object, used for SageMaker interactions (default: None). If not
specified, one is created using the default AWS configuration
chain.
serializer (sagemaker.serializers.BaseSerializer): A serializer
object, used to encode data for an inference endpoint
(default: None).
deserializer (sagemaker.deserializers.BaseDeserializer): A
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
serializer object, used to encode data for an inference endpoint
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
deserializer object, used to decode data from an inference
endpoint (default: None).
content_type (str): The invocation's "ContentType", overriding any
``CONTENT_TYPE`` from the serializer (default: None).
accept (str): The invocation's "Accept", overriding any accept from
the deserializer (default: None).
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
"""
self.endpoint_name = endpoint_name
self.sagemaker_session = sagemaker_session or Session()
self.serializer = serializer
self.deserializer = deserializer
self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None)
self.accept = accept or getattr(deserializer, "ACCEPT", None)
self.content_type = serializer.CONTENT_TYPE
self.accept = deserializer.ACCEPT
self._endpoint_config_name = self._get_endpoint_config_name()
self._model_names = self._get_model_names()

@@ -107,12 +103,8 @@ def _handle_response(self, response):
response:
"""
response_body = response["Body"]
if self.deserializer is not None:
# It's the deserializer's responsibility to close the stream
return self.deserializer.deserialize(response_body, response["ContentType"])
data = response_body.read()
response_body.close()
return data
content_type = response.get("ContentType", "application/octet-stream")
return self.deserializer.deserialize(response_body, content_type)

def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
"""
@@ -127,10 +119,10 @@ def _create_request_args(self, data, initial_args=None, targe
if "EndpointName" not in args:
args["EndpointName"] = self.endpoint_name

if self.content_type and "ContentType" not in args:
if "ContentType" not in args:
args["ContentType"] = self.content_type

if self.accept and "Accept" not in args:
if "Accept" not in args:
args["Accept"] = self.accept

if target_model:
@@ -139,8 +131,7 @@ def _create_request_args(self, data, initial_args=None, targe
if target_variant:
args["TargetVariant"] = target_variant

if self.serializer is not None:
data = self.serializer.serialize(data)
data = self.serializer.serialize(data)

args["Body"] = data
return args
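The simplified request-building logic in the predictor.py diff above can be sketched end to end. The names below are illustrative stand-ins, not the SDK's actual classes; the sketch shows the two behavioral changes: the serializer always runs (no `is None` guard), and ContentType/Accept headers always come from the serde objects unless the caller overrides them:

```python
# Sketch of the post-change _create_request_args behavior: serde defaults
# always apply, explicit initial_args still win.

class IdentitySerializer:
    CONTENT_TYPE = "application/octet-stream"

    def serialize(self, data):
        return data  # pass-through

class BytesDeserializer:
    ACCEPT = "*/*"

def create_request_args(data, serializer, deserializer, initial_args=None):
    args = dict(initial_args) if initial_args else {}
    # headers fall back to the serde objects when not supplied explicitly
    args.setdefault("ContentType", serializer.CONTENT_TYPE)
    args.setdefault("Accept", deserializer.ACCEPT)
    # the serializer is always invoked now -- no "if serializer is not None"
    args["Body"] = serializer.serialize(data)
    return args

args = create_request_args(b"payload", IdentitySerializer(), BytesDeserializer())
print(args["ContentType"])  # application/octet-stream
```

The `ACCEPT` value on `BytesDeserializer` here is an assumption for the sketch; the real class defines its own constant.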
29 changes: 29 additions & 0 deletions src/sagemaker/serializers.py
@@ -193,6 +193,35 @@ def serialize(self, data):
return json.dumps(data)


class IdentitySerializer(BaseSerializer):
"""Serialize data by returning data without modification."""

def __init__(self, content_type="application/octet-stream"):
"""Initialize the ``content_type`` attribute.

Args:
content_type (str): The MIME type of the serialized data (default:
"application/octet-stream").
"""
self.content_type = content_type

def serialize(self, data):
"""Return data without modification.

Args:
data (object): Data to be serialized.

Returns:
object: The unmodified data.
"""
return data

@property
def CONTENT_TYPE(self):
"""The MIME type of the data sent to the inference endpoint."""
return self.content_type


class JSONLinesSerializer(BaseSerializer):
"""Serialize data to a JSON Lines formatted string."""

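The new `IdentitySerializer` added in the serializers.py diff above is small enough to reproduce standalone. The re-implementation below mirrors the diff so it runs without the `sagemaker` package installed (only the empty `BaseSerializer` stub is an addition for self-containment):

```python
# Standalone copy of the IdentitySerializer from this PR, showing its
# pass-through behavior and configurable content type.

class BaseSerializer:
    """Stub standing in for sagemaker.serializers.BaseSerializer."""

class IdentitySerializer(BaseSerializer):
    """Serialize data by returning it without modification."""

    def __init__(self, content_type="application/octet-stream"):
        self.content_type = content_type

    @property
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.content_type

    def serialize(self, data):
        return data  # identity: whatever comes in goes out unchanged

ser = IdentitySerializer()
print(ser.CONTENT_TYPE)       # application/octet-stream
print(ser.serialize(b"abc"))  # b'abc'
```

Because `CONTENT_TYPE` is a property backed by the constructor argument, callers can label raw payloads with any MIME type, e.g. `IdentitySerializer("application/json")` for pre-encoded JSON bytes.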
1 change: 0 additions & 1 deletion src/sagemaker/sparkml/model.py
@@ -49,7 +49,6 @@ def __init__(self, endpoint_name, sagemaker_session=None):
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=CSVSerializer(),
content_type="text/csv",
)


6 changes: 1 addition & 5 deletions src/sagemaker/tensorflow/model.py
@@ -33,7 +33,6 @@
sagemaker_session=None,
serializer=JSONSerializer(),
deserializer=JSONDeserializer(),
content_type=None,
model_name=None,
model_version=None,
):
@@ -51,9 +50,6 @@
json. Handles dicts, lists, and numpy arrays.
deserializer (callable): Optional. Default parses the response using
``json.load(...)``.
content_type (str): Optional. The "ContentType" for invocation
requests. If specified, overrides the ``content_type`` from the
serializer (default: None).
model_name (str): Optional. The name of the SavedModel model that
should handle the request. If not specified, the endpoint's
default model will handle the request.
@@ -62,7 +58,7 @@
version of the model will be used.
"""
super(TensorFlowPredictor, self).__init__(
endpoint_name, sagemaker_session, serializer, deserializer, content_type
endpoint_name, sagemaker_session, serializer, deserializer,
)

attributes = []
2 changes: 2 additions & 0 deletions tests/integ/test_byo_estimator.py
@@ -91,6 +91,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
predictor.serializer = _FactorizationMachineSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
predictor.content_type = "application/json"
Comment on lines 92 to +94 — @bveeramani (Contributor, Author), Jul 29, 2020:

In my next PR, I'm going to add serializer and deserializer parameters to deploy. In that PR, I'll remove predictor.content_type = ....

result = predictor.predict(training_set[0][:10])

@@ -136,6 +137,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor.serializer = _FactorizationMachineSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
predictor.content_type = "application/json"
Comment on lines 138 to +140 — @bveeramani (Contributor, Author), Jul 29, 2020:

In my next PR, I'm going to add serializer and deserializer parameters to deploy. In that PR, I'll remove predictor.content_type = ....

result = predictor.predict(training_set[0][:10])

6 changes: 2 additions & 4 deletions tests/integ/test_multi_variant_endpoint.py
@@ -20,6 +20,7 @@
import scipy.stats as st

from sagemaker import image_uris
from sagemaker.deserializers import CSVDeserializer
from sagemaker.s3 import S3Uploader
from sagemaker.session import production_variant
from sagemaker.sparkml import SparkMLModel
@@ -173,8 +174,6 @@ def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant
endpoint_name=multi_variant_endpoint.endpoint_name,
sagemaker_session=sagemaker_session,
serializer=CSVSerializer(),
content_type="text/csv",
accept="text/csv",
)

# Validate that no exception is raised when the target_variant is specified.
@@ -301,8 +300,7 @@ def test_predict_invocation_with_target_variant_local_mode(
endpoint_name=multi_variant_endpoint.endpoint_name,
sagemaker_session=sagemaker_session,
serializer=CSVSerializer(),
content_type="text/csv",
accept="text/csv",
deserializer=CSVDeserializer(),
)

# Validate that no exception is raised when the target_variant is specified.