Skip to content

Commit 60f3721

Browse files
author
Balaji Veeramani
committed
Resolve merge conflicts
2 parents 50b3d39 + 0e4c0fa commit 60f3721

File tree

24 files changed

+825
-266
lines changed

24 files changed

+825
-266
lines changed

MANIFEST.in

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
recursive-include src/sagemaker *
1+
recursive-include src/sagemaker *.py
2+
3+
include src/sagemaker/image_uri_config/*.json
24

35
include VERSION
46
include LICENSE.txt

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,10 @@ For example, if you want to use JSON serialization and deserialization:
245245

246246
.. code:: python
247247
248-
from sagemaker.predictor import json_deserializer, json_serializer
248+
from sagemaker.predictor import json_deserializer
249+
from sagemaker.serializers import JSONSerializer
249250
250-
predictor.content_type = "application/json"
251-
predictor.serializer = json_serializer
251+
predictor.serializer = JSONSerializer()
252252
predictor.accept = "application/json"
253253
predictor.deserializer = json_deserializer
254254

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def read_version():
8383
packages=find_packages("src"),
8484
package_dir={"": "src"},
8585
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")],
86+
include_package_data=True,
8687
long_description=read("README.rst"),
8788
author="Amazon Web Services",
8889
url="https://github.com/aws/sagemaker-python-sdk/",

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import ast
1717

18+
from packaging.version import InvalidVersion, Version
19+
1820
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
1921
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2022

@@ -135,10 +137,15 @@ def _tf_py_version_default(framework_version):
135137
"""Gets the py_version default based on framework_version for TensorFlow."""
136138
if not framework_version:
137139
return "py2"
138-
version = [int(s) for s in framework_version.split(".")]
139-
if version < [1, 12]:
140+
141+
try:
142+
version = Version(framework_version)
143+
except InvalidVersion:
144+
return "py2"
145+
146+
if version < Version("1.12"):
140147
return "py2"
141-
if version < [2, 2]:
148+
if version < Version("2.2"):
142149
return "py3"
143150
return "py37"
144151

@@ -186,7 +193,6 @@ def _version_args_needed(node):
186193
framework, is_model = _framework_from_node(node)
187194
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
188195
if expecting_py_version:
189-
py_version = parsing.arg_value(node, PY_ARG)
190-
return py_version is None
196+
return not matching.has_arg(node, PY_ARG)
191197

192198
return False

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ def _is_legacy_mode(self, node):
7979

8080
for kw in node.keywords:
8181
if kw.arg == "script_mode":
82-
script_mode = bool(kw.value.value)
82+
script_mode = (
83+
bool(kw.value.value) if isinstance(kw.value, ast.NameConstant) else True
84+
)
8385
if kw.arg == "py_version":
84-
py_version = kw.value.s
86+
py_version = kw.value.s if isinstance(kw.value, ast.Str) else "py3"
8587

8688
return not (py_version.startswith("py3") or script_mode)
8789

@@ -124,7 +126,8 @@ def modify_node(self, node):
124126

125127
if add_image_uri:
126128
image_uri = self._image_uri_from_args(node.keywords)
127-
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
129+
if image_uri:
130+
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
128131

129132
node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))
130133

@@ -155,19 +158,22 @@ def _to_ast_keyword(self, hps):
155158
return None
156159

157160
def _image_uri_from_args(self, keywords):
158-
"""Returns a legacy TensorFlow image URI based on the estimator arguments."""
161+
"""Returns a legacy TensorFlow image URI based on the estimator arguments if possible."""
159162
tf_version = framework_version.FRAMEWORK_DEFAULTS["TensorFlow"]
160163
instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)
161164

162165
for kw in keywords:
163166
if kw.arg == "framework_version":
164-
tf_version = kw.value.s
167+
tf_version = kw.value.s if isinstance(kw.value, ast.Str) else None
165168
if kw.arg == "train_instance_type":
166-
instance_type = kw.value.s
169+
instance_type = kw.value.s if isinstance(kw.value, ast.Str) else None
167170

168-
return fw_utils.create_image_uri(
169-
self.region, "tensorflow", instance_type, tf_version, "py2"
170-
)
171+
if tf_version and instance_type:
172+
return fw_utils.create_image_uri(
173+
self.region, "tensorflow", instance_type, tf_version, "py2"
174+
)
175+
176+
return None
171177

172178

173179
class TensorBoardParameterRemover(Modifier):

src/sagemaker/deserializers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Implements methods for deserializing data returned from an inference endpoint."""
1414
from __future__ import absolute_import
1515

16+
import csv
17+
1618
import abc
1719
import codecs
1820
import io
@@ -96,6 +98,37 @@ def deserialize(self, data, content_type):
9698
data.close()
9799

98100

101+
class CSVDeserializer(BaseDeserializer):
102+
"""Deserialize a stream of bytes into a list of lists."""
103+
104+
ACCEPT = "text/csv"
105+
106+
def __init__(self, encoding="utf-8"):
107+
"""Initialize the string encoding.
108+
109+
Args:
110+
encoding (str): The string encoding to use (default: "utf-8").
111+
"""
112+
self.encoding = encoding
113+
114+
def deserialize(self, data, content_type):
115+
"""Deserialize data from an inference endpoint into a list of lists.
116+
117+
Args:
118+
data (botocore.response.StreamingBody): Data to be deserialized.
119+
content_type (str): The MIME type of the data.
120+
121+
Returns:
122+
list: The data deserialized into a list of lists representing the
123+
contents of a CSV file.
124+
"""
125+
try:
126+
decoded_string = data.read().decode(self.encoding)
127+
return list(csv.reader(decoded_string.splitlines()))
128+
finally:
129+
data.close()
130+
131+
99132
class StreamDeserializer(BaseDeserializer):
100133
"""Returns the data and content-type received from an inference endpoint.
101134
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"version_aliases": {
4+
"4.0": "4.0.0",
5+
"4.1": "4.1.0",
6+
"5.0": "5.0.0"
7+
},
8+
"versions": {
9+
"4.0.0": {
10+
"registries": {
11+
"ap-east-1": "057415533634",
12+
"ap-northeast-1": "520713654638",
13+
"ap-northeast-2": "520713654638",
14+
"ap-south-1": "520713654638",
15+
"ap-southeast-1": "520713654638",
16+
"ap-southeast-2": "520713654638",
17+
"ca-central-1": "520713654638",
18+
"cn-north-1": "422961961927",
19+
"cn-northwest-1": "423003514399",
20+
"eu-central-1": "520713654638",
21+
"eu-north-1": "520713654638",
22+
"eu-west-1": "520713654638",
23+
"eu-west-2": "520713654638",
24+
"eu-west-3": "520713654638",
25+
"me-south-1": "724002660598",
26+
"sa-east-1": "520713654638",
27+
"us-east-1": "520713654638",
28+
"us-east-2": "520713654638",
29+
"us-gov-west-1": "246785580436",
30+
"us-iso-east-1": "744548109606",
31+
"us-west-1": "520713654638",
32+
"us-west-2": "520713654638"
33+
},
34+
"repository": "sagemaker-chainer",
35+
"py_versions": ["py2", "py3"]
36+
},
37+
"4.1.0": {
38+
"registries": {
39+
"ap-east-1": "057415533634",
40+
"ap-northeast-1": "520713654638",
41+
"ap-northeast-2": "520713654638",
42+
"ap-south-1": "520713654638",
43+
"ap-southeast-1": "520713654638",
44+
"ap-southeast-2": "520713654638",
45+
"ca-central-1": "520713654638",
46+
"cn-north-1": "422961961927",
47+
"cn-northwest-1": "423003514399",
48+
"eu-central-1": "520713654638",
49+
"eu-north-1": "520713654638",
50+
"eu-west-1": "520713654638",
51+
"eu-west-2": "520713654638",
52+
"eu-west-3": "520713654638",
53+
"me-south-1": "724002660598",
54+
"sa-east-1": "520713654638",
55+
"us-east-1": "520713654638",
56+
"us-east-2": "520713654638",
57+
"us-gov-west-1": "246785580436",
58+
"us-iso-east-1": "744548109606",
59+
"us-west-1": "520713654638",
60+
"us-west-2": "520713654638"
61+
},
62+
"repository": "sagemaker-chainer",
63+
"py_versions": ["py2", "py3"]
64+
},
65+
"5.0.0": {
66+
"registries": {
67+
"ap-east-1": "057415533634",
68+
"ap-northeast-1": "520713654638",
69+
"ap-northeast-2": "520713654638",
70+
"ap-south-1": "520713654638",
71+
"ap-southeast-1": "520713654638",
72+
"ap-southeast-2": "520713654638",
73+
"ca-central-1": "520713654638",
74+
"cn-north-1": "422961961927",
75+
"cn-northwest-1": "423003514399",
76+
"eu-central-1": "520713654638",
77+
"eu-north-1": "520713654638",
78+
"eu-west-1": "520713654638",
79+
"eu-west-2": "520713654638",
80+
"eu-west-3": "520713654638",
81+
"me-south-1": "724002660598",
82+
"sa-east-1": "520713654638",
83+
"us-east-1": "520713654638",
84+
"us-east-2": "520713654638",
85+
"us-gov-west-1": "246785580436",
86+
"us-iso-east-1": "744548109606",
87+
"us-west-1": "520713654638",
88+
"us-west-2": "520713654638"
89+
},
90+
"repository": "sagemaker-chainer",
91+
"py_versions": ["py2", "py3"]
92+
}
93+
}
94+
}

src/sagemaker/image_uris.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
19+
from sagemaker import utils
20+
21+
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"
22+
23+
24+
def retrieve(framework, region, version=None, py_version=None, instance_type=None):
25+
"""Retrieves the ECR URI for the Docker image matching the given arguments.
26+
27+
Args:
28+
framework (str): The name of the framework.
29+
region (str): The AWS region.
30+
version (str): The framework version. This is required if there is
31+
more than one supported version for the given framework.
32+
py_version (str): The Python version. This is required if there is
33+
more than one supported Python version for the given framework version.
34+
instance_type (str): The SageMaker instance type. For supported types, see
35+
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
36+
there are different images for different processor types.
37+
38+
Returns:
39+
str: the ECR URI for the corresponding SageMaker Docker image.
40+
41+
Raises:
42+
ValueError: If the framework version, Python version, processor type, or region is
43+
not supported given the other arguments.
44+
"""
45+
config = config_for_framework(framework)
46+
version_config = config["versions"][_version_for_config(version, config, framework)]
47+
48+
registry = _registry_from_region(region, version_config["registries"])
49+
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
50+
51+
repo = version_config["repository"]
52+
53+
_validate_py_version(py_version, version_config["py_versions"], framework, version)
54+
tag = "{}-{}-{}".format(version, _processor(instance_type, config["processors"]), py_version)
55+
56+
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)
57+
58+
59+
def config_for_framework(framework):
60+
"""Loads the JSON config for the given framework."""
61+
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
62+
with open(fname) as f:
63+
return json.load(f)
64+
65+
66+
def _version_for_config(version, config, framework):
67+
"""Returns the version string for retrieving a framework version's specific config."""
68+
if "version_aliases" in config:
69+
if version in config["version_aliases"].keys():
70+
return config["version_aliases"][version]
71+
72+
available_versions = config["versions"].keys()
73+
if version in available_versions:
74+
return version
75+
76+
raise ValueError(
77+
"Unsupported {} version: {}. "
78+
"You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
79+
"Supported version(s): {}.".format(framework, version, ", ".join(available_versions))
80+
)
81+
82+
83+
def _registry_from_region(region, registry_dict):
84+
"""Returns the ECR registry (AWS account number) for the given region."""
85+
available_regions = registry_dict.keys()
86+
if region not in available_regions:
87+
raise ValueError(
88+
"Unsupported region: {}. You may need to upgrade "
89+
"your SDK version (pip install -U sagemaker) for newer regions. "
90+
"Supported region(s): {}.".format(region, ", ".join(available_regions))
91+
)
92+
93+
return registry_dict[region]
94+
95+
96+
def _processor(instance_type, available_processors):
97+
"""Returns the processor type for the given instance type."""
98+
if instance_type.startswith("local"):
99+
processor = "cpu" if instance_type == "local" else "gpu"
100+
elif not instance_type.startswith("ml."):
101+
raise ValueError(
102+
"Invalid SageMaker instance type: {}. See: "
103+
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
104+
)
105+
else:
106+
family = instance_type.split(".")[1]
107+
processor = "gpu" if family[0] in ("g", "p") else "cpu"
108+
109+
if processor in available_processors:
110+
return processor
111+
112+
raise ValueError(
113+
"Unsupported processor type: {} (for {}). "
114+
"Supported type(s): {}.".format(processor, instance_type, ", ".join(available_processors))
115+
)
116+
117+
118+
def _validate_py_version(py_version, available_versions, framework, fw_version):
119+
"""Checks if the Python version is one of the supported versions."""
120+
if py_version not in available_versions:
121+
raise ValueError(
122+
"Unsupported Python version for {} {}: {}. You may need to upgrade "
123+
"your SDK version (pip install -U sagemaker) for newer versions. "
124+
"Supported Python version(s): {}.".format(
125+
framework, fw_version, py_version, ", ".join(available_versions)
126+
)
127+
)

0 commit comments

Comments
 (0)