Skip to content

Commit cfa9c97

Browse files
authored
breaking: Add parameters to deploy and remove parameters from create_model (#1783)
1 parent a6f51b7 commit cfa9c97

File tree

18 files changed

+185
-52
lines changed

18 files changed

+185
-52
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ For example, if you want to use JSON serialization and deserialization:
248248
from sagemaker.deserializers import JSONDeserializer
249249
from sagemaker.serializers import JSONSerializer
250250
251-
predictor.serializer = JSONSerializer()
252-
predictor.deserializer = JSONDeserializer()
251+
predictor = model.deploy(..., serializer=JSONSerializer(), deserializer=JSONDeserializer())
253252
254253
predictor.predict(data)

doc/frameworks/xgboost/using_xgboost.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,14 @@ inference against your model.
192192

193193
.. code::
194194
195+
serializer = StringSerializer()
196+
serializer.CONTENT_TYPE = "text/libsvm"
197+
195198
predictor = estimator.deploy(
196199
initial_instance_count=1,
197-
instance_type="ml.m5.xlarge"
200+
instance_type="ml.m5.xlarge",
201+
serializer=serializer
198202
)
199-
predictor.serializer = str
200-
predictor.content_type = "text/libsvm"
201203
202204
with open("abalone") as f:
203205
payload = f.read()

doc/v2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ Please instantiate the objects instead.
8181
The ``update_endpoint`` argument in ``deploy()`` methods for estimators and models has been deprecated.
8282
Please use :func:`sagemaker.predictor.Predictor.update_endpoint` instead.
8383

84+
``serializer`` and ``deserializer`` in ``create_model()``
85+
---------------------------------------------------------
86+
87+
The ``serializer`` and ``deserializer`` arguments in
88+
:func:`sagemaker.estimator.Estimator.create_model` have been deprecated. Please
89+
specify serializers and deserializers in ``deploy()`` methods instead.
90+
8491
``content_type`` and ``accept`` in the Predictor Constructor
8592
------------------------------------------------------------
8693

src/sagemaker/automl/automl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ def deploy(
337337
self,
338338
initial_instance_count,
339339
instance_type,
340+
serializer=None,
341+
deserializer=None,
340342
candidate=None,
341343
sagemaker_session=None,
342344
name=None,
@@ -356,6 +358,16 @@ def deploy(
356358
in the ``Endpoint`` created from this ``Model``.
357359
instance_type (str): The EC2 instance type to deploy this Model to.
358360
For example, 'ml.p2.xlarge'.
361+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
362+
serializer object, used to encode data for an inference endpoint
363+
(default: None). If ``serializer`` is not None, then
364+
``serializer`` will override the default serializer. The
365+
default serializer is set by the ``predictor_cls``.
366+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
367+
deserializer object, used to decode data from an inference
368+
endpoint (default: None). If ``deserializer`` is not None, then
369+
``deserializer`` will override the default deserializer. The
370+
default deserializer is set by the ``predictor_cls``.
359371
candidate (CandidateEstimator or dict): a CandidateEstimator used for deploying
360372
to a SageMaker Inference Pipeline. If None, the best candidate will
361373
be used. If the candidate input is a dict, a CandidateEstimator will be
@@ -405,6 +417,8 @@ def deploy(
405417
return model.deploy(
406418
initial_instance_count=initial_instance_count,
407419
instance_type=instance_type,
420+
serializer=serializer,
421+
deserializer=deserializer,
408422
endpoint_name=endpoint_name,
409423
tags=tags,
410424
wait=wait,

src/sagemaker/estimator.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
from sagemaker.debugger import DebuggerHookConfig
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import get_rule_container_image_uri
32-
from sagemaker.deserializers import BytesDeserializer
3332
from sagemaker.s3 import S3Uploader, parse_s3_url
34-
from sagemaker.serializers import IdentitySerializer
3533

3634
from sagemaker.fw_utils import (
3735
tar_and_upload_dir,
@@ -670,6 +668,8 @@ def deploy(
670668
self,
671669
initial_instance_count,
672670
instance_type,
671+
serializer=None,
672+
deserializer=None,
673673
accelerator_type=None,
674674
endpoint_name=None,
675675
use_compiled_model=False,
@@ -691,6 +691,16 @@ def deploy(
691691
deploy to an endpoint for prediction.
692692
instance_type (str): Type of EC2 instance to deploy to an endpoint
693693
for prediction, for example, 'ml.c4.xlarge'.
694+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
695+
serializer object, used to encode data for an inference endpoint
696+
(default: None). If ``serializer`` is not None, then
697+
``serializer`` will override the default serializer. The
698+
default serializer is set by the ``predictor_cls``.
699+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
700+
deserializer object, used to decode data from an inference
701+
endpoint (default: None). If ``deserializer`` is not None, then
702+
``deserializer`` will override the default deserializer. The
703+
default deserializer is set by the ``predictor_cls``.
694704
accelerator_type (str): Type of Elastic Inference accelerator to
695705
attach to an endpoint for model loading and inference, for
696706
example, 'ml.eia1.medium'. If not specified, no Elastic
@@ -753,6 +763,8 @@ def deploy(
753763
return model.deploy(
754764
instance_type=instance_type,
755765
initial_instance_count=initial_instance_count,
766+
serializer=serializer,
767+
deserializer=deserializer,
756768
accelerator_type=accelerator_type,
757769
endpoint_name=endpoint_name,
758770
tags=tags or self.tags,
@@ -1367,8 +1379,6 @@ def create_model(
13671379
role=None,
13681380
image_uri=None,
13691381
predictor_cls=None,
1370-
serializer=IdentitySerializer(),
1371-
deserializer=BytesDeserializer(),
13721382
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
13731383
**kwargs
13741384
):
@@ -1386,12 +1396,6 @@ def create_model(
13861396
Defaults to the image used for training.
13871397
predictor_cls (Predictor): The predictor class to use when
13881398
deploying the model.
1389-
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
1390-
serializer object, used to encode data for an inference endpoint
1391-
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
1392-
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
1393-
deserializer object, used to decode data from an inference
1394-
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
13951399
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
13961400
the model.
13971401
Default: use subnets and security groups from this Estimator.
@@ -1410,7 +1414,7 @@ def create_model(
14101414
if predictor_cls is None:
14111415

14121416
def predict_wrapper(endpoint, session):
1413-
return Predictor(endpoint, session, serializer, deserializer)
1417+
return Predictor(endpoint, session)
14141418

14151419
predictor_cls = predict_wrapper
14161420

src/sagemaker/model.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ def deploy(
407407
self,
408408
initial_instance_count,
409409
instance_type,
410+
serializer=None,
411+
deserializer=None,
410412
accelerator_type=None,
411413
endpoint_name=None,
412414
tags=None,
@@ -433,6 +435,16 @@ def deploy(
433435
in the ``Endpoint`` created from this ``Model``.
434436
instance_type (str): The EC2 instance type to deploy this Model to.
435437
For example, 'ml.p2.xlarge', or 'local' for local mode.
438+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
439+
serializer object, used to encode data for an inference endpoint
440+
(default: None). If ``serializer`` is not None, then
441+
``serializer`` will override the default serializer. The
442+
default serializer is set by the ``predictor_cls``.
443+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
444+
deserializer object, used to decode data from an inference
445+
endpoint (default: None). If ``deserializer`` is not None, then
446+
``deserializer`` will override the default deserializer. The
447+
default deserializer is set by the ``predictor_cls``.
436448
accelerator_type (str): Type of Elastic Inference accelerator to
437449
deploy this model for model loading and inference, for example,
438450
'ml.eia1.medium'. If not specified, no Elastic Inference
@@ -499,7 +511,12 @@ def deploy(
499511
)
500512

501513
if self.predictor_cls:
502-
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
514+
predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session)
515+
if serializer:
516+
predictor.serializer = serializer
517+
if deserializer:
518+
predictor.deserializer = deserializer
519+
return predictor
503520
return None
504521

505522
def transformer(

src/sagemaker/multidatamodel.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def deploy(
143143
self,
144144
initial_instance_count,
145145
instance_type,
146+
serializer=None,
147+
deserializer=None,
146148
accelerator_type=None,
147149
endpoint_name=None,
148150
tags=None,
@@ -171,6 +173,16 @@ def deploy(
171173
in the ``Endpoint`` created from this ``Model``.
172174
instance_type (str): The EC2 instance type to deploy this Model to.
173175
For example, 'ml.p2.xlarge', or 'local' for local mode.
176+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
177+
serializer object, used to encode data for an inference endpoint
178+
(default: None). If ``serializer`` is not None, then
179+
``serializer`` will override the default serializer. The
180+
default serializer is set by the ``predictor_cls``.
181+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
182+
deserializer object, used to decode data from an inference
183+
endpoint (default: None). If ``deserializer`` is not None, then
184+
``deserializer`` will override the default deserializer. The
185+
default deserializer is set by the ``predictor_cls``.
174186
accelerator_type (str): Type of Elastic Inference accelerator to
175187
deploy this model for model loading and inference, for example,
176188
'ml.eia1.medium'. If not specified, no Elastic Inference
@@ -201,12 +213,12 @@ def deploy(
201213
enable_network_isolation = self.model.enable_network_isolation()
202214
role = self.model.role
203215
vpc_config = self.model.vpc_config
204-
predictor = self.model.predictor_cls
216+
predictor_cls = self.model.predictor_cls
205217
else:
206218
enable_network_isolation = self.enable_network_isolation()
207219
role = self.role
208220
vpc_config = self.vpc_config
209-
predictor = self.predictor_cls
221+
predictor_cls = self.predictor_cls
210222

211223
if role is None:
212224
raise ValueError("Role can not be null for deploying a model")
@@ -245,8 +257,13 @@ def deploy(
245257
data_capture_config_dict=data_capture_config_dict,
246258
)
247259

248-
if predictor:
249-
return predictor(self.endpoint_name, self.sagemaker_session)
260+
if predictor_cls:
261+
predictor = predictor_cls(self.endpoint_name, self.sagemaker_session)
262+
if serializer:
263+
predictor.serializer = serializer
264+
if deserializer:
265+
predictor.deserializer = deserializer
266+
return predictor
250267
return None
251268

252269
def add_model(self, model_data_source, model_data_path=None):

src/sagemaker/pipeline.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def deploy(
8585
self,
8686
initial_instance_count,
8787
instance_type,
88+
serializer=None,
89+
deserializer=None,
8890
endpoint_name=None,
8991
tags=None,
9092
wait=True,
@@ -110,6 +112,16 @@ def deploy(
110112
in the ``Endpoint`` created from this ``Model``.
111113
instance_type (str): The EC2 instance type to deploy this Model to.
112114
For example, 'ml.p2.xlarge'.
115+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
116+
serializer object, used to encode data for an inference endpoint
117+
(default: None). If ``serializer`` is not None, then
118+
``serializer`` will override the default serializer. The
119+
default serializer is set by the ``predictor_cls``.
120+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
121+
deserializer object, used to decode data from an inference
122+
endpoint (default: None). If ``deserializer`` is not None, then
123+
``deserializer`` will override the default deserializer. The
124+
default deserializer is set by the ``predictor_cls``.
113125
endpoint_name (str): The name of the endpoint to create (default:
114126
None). If not specified, a unique endpoint name will be created.
115127
tags (List[dict[str, str]]): The list of tags to attach to this
@@ -171,7 +183,12 @@ def deploy(
171183
)
172184

173185
if self.predictor_cls:
174-
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
186+
predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session)
187+
if serializer:
188+
predictor.serializer = serializer
189+
if deserializer:
190+
predictor.deserializer = deserializer
191+
return predictor
175192
return None
176193

177194
def _create_sagemaker_pipeline_model(self, instance_type):

src/sagemaker/predictor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ def __init__(
6262
self.sagemaker_session = sagemaker_session or Session()
6363
self.serializer = serializer
6464
self.deserializer = deserializer
65-
self.content_type = serializer.CONTENT_TYPE
66-
self.accept = deserializer.ACCEPT
6765
self._endpoint_config_name = self._get_endpoint_config_name()
6866
self._model_names = self._get_model_names()
6967

@@ -380,3 +378,13 @@ def _get_model_names(self):
380378
)
381379
production_variants = endpoint_config["ProductionVariants"]
382380
return [d["ModelName"] for d in production_variants]
381+
382+
@property
383+
def content_type(self):
384+
"""The MIME type of the data sent to the inference endpoint."""
385+
return self.serializer.CONTENT_TYPE
386+
387+
@property
388+
def accept(self):
389+
"""The content type(s) that are expected from the inference endpoint."""
390+
return self.deserializer.ACCEPT

src/sagemaker/tensorflow/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def deploy(
196196
self,
197197
initial_instance_count,
198198
instance_type,
199+
serializer=None,
200+
deserializer=None,
199201
accelerator_type=None,
200202
endpoint_name=None,
201203
tags=None,
@@ -211,6 +213,8 @@ def deploy(
211213
return super(TensorFlowModel, self).deploy(
212214
initial_instance_count=initial_instance_count,
213215
instance_type=instance_type,
216+
serializer=serializer,
217+
deserializer=deserializer,
214218
accelerator_type=accelerator_type,
215219
endpoint_name=endpoint_name,
216220
tags=tags,

src/sagemaker/tuner.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,8 @@ def deploy(
674674
self,
675675
initial_instance_count,
676676
instance_type,
677+
serializer=None,
678+
deserializer=None,
677679
accelerator_type=None,
678680
endpoint_name=None,
679681
wait=True,
@@ -693,6 +695,16 @@ def deploy(
693695
deploy to an endpoint for prediction.
694696
instance_type (str): Type of EC2 instance to deploy to an endpoint
695697
for prediction, for example, 'ml.c4.xlarge'.
698+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
699+
serializer object, used to encode data for an inference endpoint
700+
(default: None). If ``serializer`` is not None, then
701+
``serializer`` will override the default serializer. The
702+
default serializer is set by the ``predictor_cls``.
703+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
704+
deserializer object, used to decode data from an inference
705+
endpoint (default: None). If ``deserializer`` is not None, then
706+
``deserializer`` will override the default deserializer. The
707+
default deserializer is set by the ``predictor_cls``.
696708
accelerator_type (str): Type of Elastic Inference accelerator to
697709
attach to an endpoint for model loading and inference, for
698710
example, 'ml.eia1.medium'. If not specified, no Elastic
@@ -727,6 +739,8 @@ def deploy(
727739
return best_estimator.deploy(
728740
initial_instance_count=initial_instance_count,
729741
instance_type=instance_type,
742+
serializer=serializer,
743+
deserializer=deserializer,
730744
accelerator_type=accelerator_type,
731745
endpoint_name=endpoint_name or best_training_job["TrainingJobName"],
732746
wait=wait,

0 commit comments

Comments (0)