Skip to content

Commit bf1dd1c

Browse files
author
Balaji Veeramani
committed
Merge branch 'zwei' into add-json-deserializer
2 parents d87749f + 6ac82e9 commit bf1dd1c

31 files changed

+1215
-411
lines changed

MANIFEST.in

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
recursive-include src/sagemaker *
1+
recursive-include src/sagemaker *.py
2+
3+
include src/sagemaker/image_uri_config/*.json
24

35
include VERSION
46
include LICENSE.txt

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,10 @@ For example, if you want to use JSON serialization and deserialization:
245245

246246
.. code:: python
247247
248-
from sagemaker.predictor import json_serializer
249248
from sagemaker.deserializers import JSONDeserializer
249+
from sagemaker.serializers import JSONSerializer
250250
251-
predictor.content_type = "application/json"
252-
predictor.serializer = json_serializer
251+
predictor.serializer = JSONSerializer()
253252
predictor.deserializer = JSONDeserializer()
254253
255254
predictor.predict(data)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def read_version():
8383
packages=find_packages("src"),
8484
package_dir={"": "src"},
8585
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")],
86+
include_package_data=True,
8687
long_description=read("README.rst"),
8788
author="Amazon Web Services",
8889
url="https://github.com/aws/sagemaker-python-sdk/",

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.chainer import defaults
2727
from sagemaker.deserializers import NumpyDeserializer
28-
from sagemaker.predictor import Predictor, npy_serializer
28+
from sagemaker.predictor import Predictor
29+
from sagemaker.serializers import NumpySerializer
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4950
using the default AWS configuration chain.
5051
"""
5152
super(ChainerPredictor, self).__init__(
52-
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
53+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
5354
)
5455

5556

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
modifiers.renamed_params.SessionCreateModelImageURIRenamer(),
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
37+
modifiers.training_input.TrainingInputConstructorRefactor(),
3738
]
3839

3940
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
4041

4142
IMPORT_FROM_MODIFIERS = [
4243
modifiers.predictors.PredictorImportFromRenamer(),
4344
modifiers.tfs.TensorFlowServingImportFromRenamer(),
45+
modifiers.training_input.TrainingInputImportFromRenamer(),
4446
]
4547

4648

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
tf_legacy_mode,
2323
tfs,
2424
training_params,
25+
training_input,
2526
)

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import ast
1717

18+
from packaging.version import InvalidVersion, Version
19+
1820
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
1921
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2022

@@ -135,10 +137,15 @@ def _tf_py_version_default(framework_version):
135137
"""Gets the py_version default based on framework_version for TensorFlow."""
136138
if not framework_version:
137139
return "py2"
138-
version = [int(s) for s in framework_version.split(".")]
139-
if version < [1, 12]:
140+
141+
try:
142+
version = Version(framework_version)
143+
except InvalidVersion:
144+
return "py2"
145+
146+
if version < Version("1.12"):
140147
return "py2"
141-
if version < [2, 2]:
148+
if version < Version("2.2"):
142149
return "py3"
143150
return "py37"
144151

@@ -186,7 +193,6 @@ def _version_args_needed(node):
186193
framework, is_model = _framework_from_node(node)
187194
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
188195
if expecting_py_version:
189-
py_version = parsing.arg_value(node, PY_ARG)
190-
return py_version is None
196+
return not matching.has_arg(node, PY_ARG)
191197

192198
return False

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ def _is_legacy_mode(self, node):
7979

8080
for kw in node.keywords:
8181
if kw.arg == "script_mode":
82-
script_mode = bool(kw.value.value)
82+
script_mode = (
83+
bool(kw.value.value) if isinstance(kw.value, ast.NameConstant) else True
84+
)
8385
if kw.arg == "py_version":
84-
py_version = kw.value.s
86+
py_version = kw.value.s if isinstance(kw.value, ast.Str) else "py3"
8587

8688
return not (py_version.startswith("py3") or script_mode)
8789

@@ -124,7 +126,8 @@ def modify_node(self, node):
124126

125127
if add_image_uri:
126128
image_uri = self._image_uri_from_args(node.keywords)
127-
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
129+
if image_uri:
130+
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
128131

129132
node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))
130133

@@ -155,19 +158,22 @@ def _to_ast_keyword(self, hps):
155158
return None
156159

157160
def _image_uri_from_args(self, keywords):
158-
"""Returns a legacy TensorFlow image URI based on the estimator arguments."""
161+
"""Returns a legacy TensorFlow image URI based on the estimator arguments if possible."""
159162
tf_version = framework_version.FRAMEWORK_DEFAULTS["TensorFlow"]
160163
instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)
161164

162165
for kw in keywords:
163166
if kw.arg == "framework_version":
164-
tf_version = kw.value.s
167+
tf_version = kw.value.s if isinstance(kw.value, ast.Str) else None
165168
if kw.arg == "train_instance_type":
166-
instance_type = kw.value.s
169+
instance_type = kw.value.s if isinstance(kw.value, ast.Str) else None
167170

168-
return fw_utils.create_image_uri(
169-
self.region, "tensorflow", instance_type, tf_version, "py2"
170-
)
171+
if tf_version and instance_type:
172+
return fw_utils.create_image_uri(
173+
self.region, "tensorflow", instance_type, tf_version, "py2"
174+
)
175+
176+
return None
171177

172178

173179
class TensorBoardParameterRemover(Modifier):
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify TrainingInput code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import ast
19+
20+
from sagemaker.cli.compatibility.v2.modifiers import matching
21+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
22+
23+
S3_INPUT_NAME = "s3_input"
24+
S3_INPUT_NAMESPACES = ("sagemaker", "sagemaker.inputs", "sagemaker.session")
25+
26+
27+
class TrainingInputConstructorRefactor(Modifier):
28+
"""A class to refactor *s3_input class."""
29+
30+
def node_should_be_modified(self, node):
31+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
32+
33+
This looks for the following calls:
34+
35+
- ``sagemaker.s3_input``
36+
- ``sagemaker.session.s3_input``
37+
- ``s3_input``
38+
39+
Args:
40+
node (ast.Call): a node that represents a function call. For more,
41+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
42+
43+
Returns:
44+
bool: If the ``ast.Call`` instantiates a class of interest.
45+
"""
46+
return matching.matches_name_or_namespaces(node, S3_INPUT_NAME, S3_INPUT_NAMESPACES)
47+
48+
def modify_node(self, node):
49+
"""Modifies the ``ast.Call`` node to call ``TrainingInput`` instead.
50+
51+
Args:
52+
node (ast.Call): a node that represents a *TrainingInput constructor.
53+
"""
54+
if matching.matches_name(node, S3_INPUT_NAME):
55+
node.func.id = "TrainingInput"
56+
elif matching.matches_attr(node, S3_INPUT_NAME):
57+
node.func.attr = "TrainingInput"
58+
_rename_namespace(node, "session")
59+
60+
61+
def _rename_namespace(node, name):
62+
"""Rename namespace ``session`` to ``inputs`` """
63+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == name:
64+
node.func.value.attr = "inputs"
65+
elif isinstance(node.func.value, ast.Name) and node.func.value.id == name:
66+
node.func.value.id = "inputs"
67+
68+
69+
class TrainingInputImportFromRenamer(Modifier):
70+
"""A class to update import statements of ``s3_input``."""
71+
72+
def node_should_be_modified(self, node):
73+
"""Checks if the import statement imports ``s3_input`` from the correct module.
74+
75+
Args:
76+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
77+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
78+
79+
Returns:
80+
bool: If the import statement imports ``s3_input`` from the correct module.
81+
"""
82+
return node.module in S3_INPUT_NAMESPACES and any(
83+
name.name == S3_INPUT_NAME for name in node.names
84+
)
85+
86+
def modify_node(self, node):
87+
"""Changes the ``ast.ImportFrom`` node's name from ``s3_input`` to ``TrainingInput``.
88+
89+
Args:
90+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
91+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
92+
"""
93+
for name in node.names:
94+
if name.name == S3_INPUT_NAME:
95+
name.name = "TrainingInput"
96+
if node.module == "sagemaker.session":
97+
node.module = "sagemaker.inputs"

src/sagemaker/deserializers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Implements methods for deserializing data returned from an inference endpoint."""
1414
from __future__ import absolute_import
1515

16+
import csv
17+
1618
import abc
1719
import codecs
1820
import io
@@ -96,6 +98,37 @@ def deserialize(self, data, content_type):
9698
data.close()
9799

98100

101+
class CSVDeserializer(BaseDeserializer):
102+
"""Deserialize a stream of bytes into a list of lists."""
103+
104+
ACCEPT = "text/csv"
105+
106+
def __init__(self, encoding="utf-8"):
107+
"""Initialize the string encoding.
108+
109+
Args:
110+
encoding (str): The string encoding to use (default: "utf-8").
111+
"""
112+
self.encoding = encoding
113+
114+
def deserialize(self, data, content_type):
115+
"""Deserialize data from an inference endpoint into a list of lists.
116+
117+
Args:
118+
data (botocore.response.StreamingBody): Data to be deserialized.
119+
content_type (str): The MIME type of the data.
120+
121+
Returns:
122+
list: The data deserialized into a list of lists representing the
123+
contents of a CSV file.
124+
"""
125+
try:
126+
decoded_string = data.read().decode(self.encoding)
127+
return list(csv.reader(decoded_string.splitlines()))
128+
finally:
129+
data.close()
130+
131+
99132
class StreamDeserializer(BaseDeserializer):
100133
"""Returns the data and content-type received from an inference endpoint.
101134

0 commit comments

Comments
 (0)