
Commit 11616f2

Author: Balaji Veeramani (committed)
Remove content_type and accept parameters
1 parent 6b76093 commit 11616f2

File tree

8 files changed (+79, -119 lines)


src/sagemaker/algorithm.py

Lines changed: 14 additions & 20 deletions
@@ -16,7 +16,9 @@
 import sagemaker
 import sagemaker.parameter
 from sagemaker import vpc_utils
+from sagemaker.deserializers import BytesDeserializer
 from sagemaker.estimator import EstimatorBase
+from sagemaker.serializers import IdentitySerializer
 from sagemaker.transformer import Transformer
 from sagemaker.predictor import Predictor

@@ -251,37 +253,29 @@ def create_model(
         self,
         role=None,
         predictor_cls=None,
-        serializer=None,
-        deserializer=None,
-        content_type=None,
-        accept=None,
+        serializer=IdentitySerializer(),
+        deserializer=BytesDeserializer(),
         vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
         **kwargs
     ):
         """Create a model to deploy.

-        The serializer, deserializer, content_type, and accept arguments are
-        only used to define a default Predictor. They are ignored if an
-        explicit predictor class is passed in. Other arguments are passed
-        through to the Model class.
+        The serializer and deserializer are only used to define a default
+        Predictor. They are ignored if an explicit predictor class is passed in.
+        Other arguments are passed through to the Model class.

         Args:
             role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
                 which is also used during transform jobs. If not specified, the
                 role from the Estimator will be used.
             predictor_cls (Predictor): The predictor class to use when
                 deploying the model.
-            serializer (callable): Should accept a single argument, the input
-                data, and return a sequence of bytes. May provide a content_type
-                attribute that defines the endpoint request content type
-            deserializer (callable): Should accept two arguments, the result
-                data and the response content type, and return a sequence of
-                bytes. May provide a content_type attribute that defines the
-                endpoint response Accept content type.
-            content_type (str): The invocation ContentType, overriding any
-                content_type from the serializer
-            accept (str): The invocation Accept, overriding any accept from the
-                deserializer.
+            serializer (sagemaker.serializers.BaseSerializer): A serializer
+                object, used to encode data for an inference endpoint
+                (default: ``IdentitySerializer()``).
+            deserializer (sagemaker.deserializers.BaseDeserializer): A
+                deserializer object, used to decode data from an inference
+                endpoint (default: ``BytesDeserializer()``).
             vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
                 the model. Default: use subnets and security groups from this Estimator.
                 * 'Subnets' (list[str]): List of subnet ids.
@@ -300,7 +294,7 @@ def create_model(
         if predictor_cls is None:

             def predict_wrapper(endpoint, session):
-                return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
+                return Predictor(endpoint, session, serializer, deserializer)

             predictor_cls = predict_wrapper
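With the content_type and accept parameters gone, the invocation ContentType and Accept now come entirely from the serializer and deserializer objects. A minimal sketch of the attributes the default predict_wrapper relies on (the printed values are only what this commit and the SDK define; nothing else is assumed):

    from sagemaker.deserializers import BytesDeserializer
    from sagemaker.serializers import IdentitySerializer

    serializer = IdentitySerializer()
    deserializer = BytesDeserializer()

    # The default Predictor built by predict_wrapper derives its headers from these objects:
    print(serializer.CONTENT_TYPE)  # invocation ContentType ("application/json" per this commit)
    print(deserializer.ACCEPT)      # invocation Accept, as defined by BytesDeserializer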

src/sagemaker/estimator.py

Lines changed: 12 additions & 17 deletions
@@ -30,7 +30,9 @@
 from sagemaker.debugger import DebuggerHookConfig
 from sagemaker.debugger import TensorBoardOutputConfig  # noqa: F401 # pylint: disable=unused-import
 from sagemaker.debugger import get_rule_container_image_uri
+from sagemaker.deserializers import BytesDeserializer
 from sagemaker.s3 import S3Uploader
+from sagemaker.serializers import IdentitySerializer

 from sagemaker.fw_utils import (
     tar_and_upload_dir,
@@ -1343,16 +1345,14 @@ def create_model(
         role=None,
         image_uri=None,
         predictor_cls=None,
-        serializer=None,
-        deserializer=None,
-        content_type=None,
-        accept=None,
+        serializer=IdentitySerializer(),
+        deserializer=BytesDeserializer(),
         vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
         **kwargs
     ):
         """Create a model to deploy.

-        The serializer, deserializer, content_type, and accept arguments are only used to define a
+        The serializer and deserializer arguments are only used to define a
         default Predictor. They are ignored if an explicit predictor class is passed in.
         Other arguments are passed through to the Model class.

@@ -1364,17 +1364,12 @@ def create_model(
                 Defaults to the image used for training.
             predictor_cls (Predictor): The predictor class to use when
                 deploying the model.
-            serializer (callable): Should accept a single argument, the input
-                data, and return a sequence of bytes. May provide a content_type
-                attribute that defines the endpoint request content type
-            deserializer (callable): Should accept two arguments, the result
-                data and the response content type, and return a sequence of
-                bytes. May provide a content_type attribute that defines th
-                endpoint response Accept content type.
-            content_type (str): The invocation ContentType, overriding any
-                content_type from the serializer
-            accept (str): The invocation Accept, overriding any accept from the
-                deserializer.
+            serializer (sagemaker.serializers.BaseSerializer): A serializer
+                object, used to encode data for an inference endpoint
+                (default: ``IdentitySerializer()``).
+            deserializer (sagemaker.deserializers.BaseDeserializer): A
+                deserializer object, used to decode data from an inference
+                endpoint (default: ``BytesDeserializer()``).
             vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
                 the model.
                 Default: use subnets and security groups from this Estimator.
@@ -1393,7 +1388,7 @@ def create_model(
         if predictor_cls is None:

             def predict_wrapper(endpoint, session):
-                return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
+                return Predictor(endpoint, session, serializer, deserializer)

             predictor_cls = predict_wrapper
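Callers that previously passed content_type= or accept= to create_model() now express the same intent by choosing serializer and deserializer objects. A hedged migration sketch, assuming an already-trained estimator object (the variable name and the CSV/bytes choice are illustrative, not part of this commit):

    from sagemaker.deserializers import BytesDeserializer
    from sagemaker.serializers import CSVSerializer

    # Before this commit: estimator.create_model(content_type="text/csv", accept="*/*")
    # After: the ContentType comes from the serializer and the Accept from the deserializer.
    model = estimator.create_model(
        serializer=CSVSerializer(),        # drives the invocation ContentType
        deserializer=BytesDeserializer(),  # drives the invocation Accept
    )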

src/sagemaker/predictor.py

Lines changed: 17 additions & 28 deletions
@@ -13,9 +13,9 @@
1313
"""Placeholder docstring"""
1414
from __future__ import print_function, absolute_import
1515

16-
from sagemaker.deserializers import BaseDeserializer
16+
from sagemaker.deserializers import BaseDeserializer, BytesDeserializer
1717
from sagemaker.model_monitor import DataCaptureConfig
18-
from sagemaker.serializers import BaseSerializer
18+
from sagemaker.serializers import BaseSerializer, IdentitySerializer
1919
from sagemaker.session import production_variant, Session
2020
from sagemaker.utils import name_from_base
2121

@@ -33,10 +33,8 @@ def __init__(
         self,
         endpoint_name,
         sagemaker_session=None,
-        serializer=None,
-        deserializer=None,
-        content_type=None,
-        accept=None,
+        serializer=IdentitySerializer(),
+        deserializer=BytesDeserializer(),
     ):
         """Initialize a ``Predictor``.

@@ -55,14 +53,10 @@ def __init__(
                 chain.
             serializer (sagemaker.serializers.BaseSerializer): A serializer
                 object, used to encode data for an inference endpoint
-                (default: None).
+                (default: ``IdentitySerializer()``).
             deserializer (sagemaker.deserializers.BaseDeserializer): A
                 deserializer object, used to decode data from an inference
-                endpoint (default: None).
-            content_type (str): The invocation's "ContentType", overriding any
-                ``CONTENT_TYPE`` from the serializer (default: None).
-            accept (str): The invocation's "Accept", overriding any accept from
-                the deserializer (default: None).
+                endpoint (default: ``BytesDeserializer()``).
         """
         if serializer is not None and not isinstance(serializer, BaseSerializer):
             serializer = LegacySerializer(serializer)
@@ -73,8 +67,8 @@ def __init__(
         self.sagemaker_session = sagemaker_session or Session()
         self.serializer = serializer
         self.deserializer = deserializer
-        self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None)
-        self.accept = accept or getattr(deserializer, "ACCEPT", None)
+        self.content_type = serializer.CONTENT_TYPE
+        self.accept = deserializer.ACCEPT
         self._endpoint_config_name = self._get_endpoint_config_name()
         self._model_names = self._get_model_names()

@@ -114,14 +108,10 @@ def _handle_response(self, response):
            response:
        """
        response_body = response["Body"]
-        if self.deserializer is not None:
-            if not isinstance(self.deserializer, BaseDeserializer):
-                self.deserializer = LegacyDeserializer(self.deserializer)
-            # It's the deserializer's responsibility to close the stream
-            return self.deserializer.deserialize(response_body, response["ContentType"])
-        data = response_body.read()
-        response_body.close()
-        return data
+        content_type = response.get("ContentType", "application/octet-stream")
+        if not isinstance(self.deserializer, BaseDeserializer):
+            self.deserializer = LegacyDeserializer(self.deserializer)
+        return self.deserializer.deserialize(response_body, content_type)

    def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
        """
@@ -136,10 +126,10 @@ def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
        if "EndpointName" not in args:
            args["EndpointName"] = self.endpoint_name

-        if self.content_type and "ContentType" not in args:
+        if "ContentType" not in args:
            args["ContentType"] = self.content_type

-        if self.accept and "Accept" not in args:
+        if "Accept" not in args:
            args["Accept"] = self.accept

        if target_model:
@@ -148,10 +138,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
        if target_variant:
            args["TargetVariant"] = target_variant

-        if self.serializer is not None:
-            if not isinstance(self.serializer, BaseSerializer):
-                self.serializer = LegacySerializer(self.serializer)
-            data = self.serializer.serialize(data)
+        if not isinstance(self.serializer, BaseSerializer):
+            self.serializer = LegacySerializer(self.serializer)
+        data = self.serializer.serialize(data)

        args["Body"] = data
        return args
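Because the None checks are gone, whatever is supplied as serializer/deserializer must expose what the Predictor now reads unconditionally: serialize() plus CONTENT_TYPE on the request side, and deserialize(stream, content_type) plus ACCEPT on the response side. A rough sketch of a custom pair built on the base classes imported in this file (the class names, parameter names, and content types are illustrative, not part of this commit):

    from sagemaker.deserializers import BaseDeserializer
    from sagemaker.serializers import BaseSerializer


    class UTF8Serializer(BaseSerializer):
        """Illustrative serializer: encodes strings as UTF-8 bytes."""

        CONTENT_TYPE = "text/plain"  # becomes the invocation ContentType

        def serialize(self, data):
            return data.encode("utf-8")


    class UTF8Deserializer(BaseDeserializer):
        """Illustrative deserializer: decodes the response stream as UTF-8 text."""

        ACCEPT = "text/plain"  # becomes the invocation Accept

        def deserialize(self, stream, content_type):
            try:
                return stream.read().decode("utf-8")
            finally:
                stream.close()  # the deserializer is responsible for closing the stream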

src/sagemaker/serializers.py

Lines changed: 17 additions & 0 deletions
@@ -183,3 +183,20 @@ def serialize(self, data):
             return json.dumps(data.tolist())

         return json.dumps(data)
+
+
+class IdentitySerializer(BaseSerializer):
+    """Serialize data by returning data without modification."""
+
+    CONTENT_TYPE = "application/json"
+
+    def serialize(self, data):
+        """Return data without modification.
+
+        Args:
+            data (object): Data to be serialized.
+
+        Returns:
+            object: The unmodified data.
+        """
+        return data
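A quick usage sketch of the new class (plain Python, no endpoint involved; the payload value is just an example):

    from sagemaker.serializers import IdentitySerializer

    serializer = IdentitySerializer()
    payload = b'{"instances": [1.0, 2.0, 3.0]}'  # example payload

    assert serializer.serialize(payload) is payload  # data passes through untouched
    print(serializer.CONTENT_TYPE)                   # "application/json" in this commit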

src/sagemaker/sparkml/model.py

Lines changed: 0 additions & 2 deletions
@@ -14,7 +14,6 @@
 from __future__ import absolute_import

 from sagemaker import Model, Predictor, Session, image_uris
-from sagemaker.content_types import CONTENT_TYPE_CSV
 from sagemaker.serializers import CSVSerializer

 framework_name = "sparkml-serving"
@@ -50,7 +49,6 @@ def __init__(self, endpoint_name, sagemaker_session=None):
             endpoint_name=endpoint_name,
             sagemaker_session=sagemaker_session,
             serializer=CSVSerializer(),
-            content_type=CONTENT_TYPE_CSV,
         )

src/sagemaker/tensorflow/model.py

Lines changed: 1 addition & 5 deletions
@@ -34,7 +34,6 @@ def __init__(
         sagemaker_session=None,
         serializer=JSONSerializer(),
         deserializer=JSONDeserializer(),
-        content_type=None,
         model_name=None,
         model_version=None,
     ):
@@ -52,9 +51,6 @@ def __init__(
                 json. Handles dicts, lists, and numpy arrays.
             deserializer (callable): Optional. Default parses the response using
                 ``json.load(...)``.
-            content_type (str): Optional. The "ContentType" for invocation
-                requests. If specified, overrides the ``content_type`` from the
-                serializer (default: None).
             model_name (str): Optional. The name of the SavedModel model that
                 should handle the request. If not specified, the endpoint's
                 default model will handle the request.
@@ -63,7 +59,7 @@ def __init__(
                 version of the model will be used.
         """
         super(TensorFlowPredictor, self).__init__(
-            endpoint_name, sagemaker_session, serializer, deserializer, content_type
+            endpoint_name, sagemaker_session, serializer, deserializer,
         )

         attributes = []
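Code that previously passed content_type= to TensorFlowPredictor can set the attribute after construction instead, which is the pattern the updated test below adopts. A hedged sketch (the endpoint name and sagemaker_session are hypothetical placeholders, not part of this commit):

    from sagemaker.serializers import IdentitySerializer
    from sagemaker.tensorflow import TensorFlowPredictor

    # "endpoint" and sagemaker_session stand in for a real endpoint name and Session.
    predictor = TensorFlowPredictor("endpoint", sagemaker_session, serializer=IdentitySerializer())
    predictor.content_type = "application/jsons"  # overrides the ContentType derived from the serializer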

tests/unit/sagemaker/tensorflow/test_tfs.py

Lines changed: 3 additions & 4 deletions
@@ -20,7 +20,7 @@
 import pytest
 from mock import Mock, patch

-from sagemaker.serializers import CSVSerializer
+from sagemaker.serializers import CSVSerializer, IdentitySerializer
 from sagemaker.tensorflow import TensorFlow, TensorFlowModel, TensorFlowPredictor

 JSON_CONTENT_TYPE = "application/json"
@@ -318,9 +318,8 @@ def test_predictor(sagemaker_session):


 def test_predictor_jsons(sagemaker_session):
-    predictor = TensorFlowPredictor(
-        "endpoint", sagemaker_session, serializer=None, content_type="application/jsons"
-    )
+    predictor = TensorFlowPredictor("endpoint", sagemaker_session, serializer=IdentitySerializer())
+    predictor.content_type = "application/jsons"

     mock_response(json.dumps(PREDICT_RESPONSE).encode("utf-8"), sagemaker_session)
     result = predictor.predict("[1.0, 2.0, 3.0]\n[4.0, 5.0, 6.0]")
