Skip to content

Commit 5e68e01

Browse files
author
Chuyang Deng
committed
Merge branch 'json-update-script' of github.com:ChuyangDeng/sagemaker-python-sdk into json-update-script
2 parents e0cc4e1 + e8e4f4e commit 5e68e01

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+352
-1652
lines changed

doc/v2.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,18 @@ Please instantiate the objects instead.
9494
The ``update_endpoint`` argument in ``deploy()`` methods for estimators and models has been deprecated.
9595
Please use :func:`sagemaker.predictor.Predictor.update_endpoint` instead.
9696

97+
``content_type`` and ``accept`` in the Predictor Constructor
98+
------------------------------------------------------------
99+
100+
The ``content_type`` and ``accept`` parameters have been removed from the
101+
following classes and methods:
102+
- ``sagemaker.predictor.Predictor``
103+
- ``sagemaker.estimator.Estimator.create_model``
104+
- ``sagemaker.algorithms.AlgorithmEstimator.create_model``
105+
- ``sagemaker.tensorflow.model.TensorFlowPredictor``
106+
107+
Please specify content types in a serializer or deserializer class instead.
108+
97109
``sagemaker.content_types``
98110
---------------------------
99111

@@ -115,7 +127,6 @@ write MIME types as a string,
115127
| ``CONTENT_TYPE_NPY`` | ``"application/x-npy"`` |
116128
+-------------------------------+--------------------------------+
117129

118-
119130
Require ``framework_version`` and ``py_version`` for Frameworks
120131
===============================================================
121132

src/sagemaker/algorithm.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import sagemaker
1717
import sagemaker.parameter
1818
from sagemaker import vpc_utils
19+
from sagemaker.deserializers import BytesDeserializer
1920
from sagemaker.estimator import EstimatorBase
21+
from sagemaker.serializers import IdentitySerializer
2022
from sagemaker.transformer import Transformer
2123
from sagemaker.predictor import Predictor
2224

@@ -251,37 +253,29 @@ def create_model(
251253
self,
252254
role=None,
253255
predictor_cls=None,
254-
serializer=None,
255-
deserializer=None,
256-
content_type=None,
257-
accept=None,
256+
serializer=IdentitySerializer(),
257+
deserializer=BytesDeserializer(),
258258
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
259259
**kwargs
260260
):
261261
"""Create a model to deploy.
262262
263-
The serializer, deserializer, content_type, and accept arguments are
264-
only used to define a default Predictor. They are ignored if an
265-
explicit predictor class is passed in. Other arguments are passed
266-
through to the Model class.
263+
The serializer and deserializer are only used to define a default
264+
Predictor. They are ignored if an explicit predictor class is passed in.
265+
Other arguments are passed through to the Model class.
267266
268267
Args:
269268
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
270269
which is also used during transform jobs. If not specified, the
271270
role from the Estimator will be used.
272271
predictor_cls (Predictor): The predictor class to use when
273272
deploying the model.
274-
serializer (callable): Should accept a single argument, the input
275-
data, and return a sequence of bytes. May provide a content_type
276-
attribute that defines the endpoint request content type
277-
deserializer (callable): Should accept two arguments, the result
278-
data and the response content type, and return a sequence of
279-
bytes. May provide a content_type attribute that defines the
280-
endpoint response Accept content type.
281-
content_type (str): The invocation ContentType, overriding any
282-
content_type from the serializer
283-
accept (str): The invocation Accept, overriding any accept from the
284-
deserializer.
273+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
274+
serializer object, used to encode data for an inference endpoint
275+
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
276+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
277+
deserializer object, used to decode data from an inference
278+
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
285279
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
286280
the model. Default: use subnets and security groups from this Estimator.
287281
* 'Subnets' (list[str]): List of subnet ids.
@@ -300,7 +294,7 @@ def create_model(
300294
if predictor_cls is None:
301295

302296
def predict_wrapper(endpoint, session):
303-
return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
297+
return Predictor(endpoint, session, serializer, deserializer)
304298

305299
predictor_cls = predict_wrapper
306300

src/sagemaker/chainer/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class Chainer(Framework):
3333
"""Handle end-to-end training and deployment of custom Chainer code."""
3434

35-
__framework_name__ = "chainer"
35+
_framework_name = "chainer"
3636

3737
# Hyperparameters
3838
_use_mpi = "sagemaker_use_mpi"
@@ -131,7 +131,7 @@ def __init__(
131131
validate_version_or_image_args(framework_version, py_version, image_uri)
132132
if py_version == "py2":
133133
logger.warning(
134-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
134+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
135135
)
136136
self.framework_version = framework_version
137137
self.py_version = py_version
@@ -272,7 +272,7 @@ class constructor
272272
init_params["image_uri"] = image_uri
273273
return init_params
274274

275-
if framework != cls.__framework_name__:
275+
if framework != cls._framework_name:
276276
raise ValueError(
277277
"Training job: {} didn't use image for requested framework".format(
278278
job_details["TrainingJobName"]

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ChainerModel(FrameworkModel):
5959
``Endpoint``.
6060
"""
6161

62-
__framework_name__ = "chainer"
62+
_framework_name = "chainer"
6363

6464
def __init__(
6565
self,
@@ -116,7 +116,7 @@ def __init__(
116116
validate_version_or_image_args(framework_version, py_version, image_uri)
117117
if py_version == "py2":
118118
logger.warning(
119-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
120120
)
121121
self.framework_version = framework_version
122122
self.py_version = py_version
@@ -176,7 +176,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
176176
177177
"""
178178
return image_uris.retrieve(
179-
self.__framework_name__,
179+
self._framework_name,
180180
region_name,
181181
version=self.framework_version,
182182
py_version=self.py_version,

src/sagemaker/debugger.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,9 @@
2323

2424
import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
2525

26-
from sagemaker.utils import get_ecr_image_uri_prefix
27-
28-
RULES_ECR_REPO_NAME = "sagemaker-debugger-rules"
29-
30-
SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP = {
31-
"eu-north-1": {RULES_ECR_REPO_NAME: "314864569078"},
32-
"me-south-1": {RULES_ECR_REPO_NAME: "986000313247"},
33-
"ap-south-1": {RULES_ECR_REPO_NAME: "904829902805"},
34-
"eu-west-3": {RULES_ECR_REPO_NAME: "447278800020"},
35-
"us-east-2": {RULES_ECR_REPO_NAME: "915447279597"},
36-
"eu-west-1": {RULES_ECR_REPO_NAME: "929884845733"},
37-
"eu-central-1": {RULES_ECR_REPO_NAME: "482524230118"},
38-
"sa-east-1": {RULES_ECR_REPO_NAME: "818342061345"},
39-
"ap-east-1": {RULES_ECR_REPO_NAME: "199566480951"},
40-
"us-east-1": {RULES_ECR_REPO_NAME: "503895931360"},
41-
"ap-northeast-2": {RULES_ECR_REPO_NAME: "578805364391"},
42-
"eu-west-2": {RULES_ECR_REPO_NAME: "250201462417"},
43-
"ap-northeast-1": {RULES_ECR_REPO_NAME: "430734990657"},
44-
"us-west-2": {RULES_ECR_REPO_NAME: "895741380848"},
45-
"us-west-1": {RULES_ECR_REPO_NAME: "685455198987"},
46-
"ap-southeast-1": {RULES_ECR_REPO_NAME: "972752614525"},
47-
"ap-southeast-2": {RULES_ECR_REPO_NAME: "184798709955"},
48-
"ca-central-1": {RULES_ECR_REPO_NAME: "519511493484"},
49-
"cn-north-1": {RULES_ECR_REPO_NAME: "618459771430"},
50-
"cn-northwest-1": {RULES_ECR_REPO_NAME: "658757709296"},
51-
}
26+
from sagemaker import image_uris
27+
28+
framework_name = "debugger"
5229

5330

5431
def get_rule_container_image_uri(region):
@@ -61,9 +38,7 @@ def get_rule_container_image_uri(region):
6138
Returns:
6239
str: Formatted image uri for the given region and the rule container type
6340
"""
64-
registry_id = SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP.get(region).get(RULES_ECR_REPO_NAME)
65-
image_uri_prefix = get_ecr_image_uri_prefix(registry_id, region)
66-
return "{}/{}:latest".format(image_uri_prefix, RULES_ECR_REPO_NAME)
41+
return image_uris.retrieve(framework_name, region)
6742

6843

6944
class Rule(object):

src/sagemaker/deserializers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,8 @@ def deserialize(self, stream, content_type):
266266
list: A list of JSON serializable objects.
267267
"""
268268
try:
269-
lines = stream.read().rstrip().split("\n")
269+
body = stream.read().decode("utf-8")
270+
lines = body.rstrip().split("\n")
270271
return [json.loads(line) for line in lines]
271272
finally:
272273
stream.close()

src/sagemaker/estimator.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@
2929
from sagemaker.debugger import DebuggerHookConfig
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import get_rule_container_image_uri
32-
from sagemaker.s3 import S3Uploader
32+
from sagemaker.deserializers import BytesDeserializer
33+
from sagemaker.s3 import S3Uploader, parse_s3_url
34+
from sagemaker.serializers import IdentitySerializer
3335

3436
from sagemaker.fw_utils import (
3537
tar_and_upload_dir,
36-
parse_s3_url,
3738
UploadedCode,
3839
validate_source_dir,
3940
_region_supports_debugger,
@@ -1341,16 +1342,14 @@ def create_model(
13411342
role=None,
13421343
image_uri=None,
13431344
predictor_cls=None,
1344-
serializer=None,
1345-
deserializer=None,
1346-
content_type=None,
1347-
accept=None,
1345+
serializer=IdentitySerializer(),
1346+
deserializer=BytesDeserializer(),
13481347
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
13491348
**kwargs
13501349
):
13511350
"""Create a model to deploy.
13521351
1353-
The serializer, deserializer, content_type, and accept arguments are only used to define a
1352+
The serializer and deserializer arguments are only used to define a
13541353
default Predictor. They are ignored if an explicit predictor class is passed in.
13551354
Other arguments are passed through to the Model class.
13561355
@@ -1362,17 +1361,12 @@ def create_model(
13621361
Defaults to the image used for training.
13631362
predictor_cls (Predictor): The predictor class to use when
13641363
deploying the model.
1365-
serializer (callable): Should accept a single argument, the input
1366-
data, and return a sequence of bytes. May provide a content_type
1367-
attribute that defines the endpoint request content type
1368-
deserializer (callable): Should accept two arguments, the result
1369-
data and the response content type, and return a sequence of
1370-
bytes. May provide a content_type attribute that defines th
1371-
endpoint response Accept content type.
1372-
content_type (str): The invocation ContentType, overriding any
1373-
content_type from the serializer
1374-
accept (str): The invocation Accept, overriding any accept from the
1375-
deserializer.
1364+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
1365+
serializer object, used to encode data for an inference endpoint
1366+
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
1367+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
1368+
deserializer object, used to decode data from an inference
1369+
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
13761370
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
13771371
the model.
13781372
Default: use subnets and security groups from this Estimator.
@@ -1391,7 +1385,7 @@ def create_model(
13911385
if predictor_cls is None:
13921386

13931387
def predict_wrapper(endpoint, session):
1394-
return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
1388+
return Predictor(endpoint, session, serializer, deserializer)
13951389

13961390
predictor_cls = predict_wrapper
13971391

@@ -1418,7 +1412,7 @@ class Framework(EstimatorBase):
14181412
such as training/deployment images and predictor instances.
14191413
"""
14201414

1421-
__framework_name__ = None
1415+
_framework_name = None
14221416

14231417
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
14241418
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
@@ -1816,7 +1810,7 @@ def train_image(self):
18161810
if self.image_uri:
18171811
return self.image_uri
18181812
return image_uris.retrieve(
1819-
self.__framework_name__,
1813+
self._framework_name,
18201814
self.sagemaker_session.boto_region_name,
18211815
instance_type=self.instance_type,
18221816
version=self.framework_version, # pylint: disable=no-member

0 commit comments

Comments
 (0)