Skip to content

Commit 20d9faa

Browse files
authored
Merge branch 'zwei' into deploy-new-resources
2 parents f04ca11 + cde5500 commit 20d9faa

File tree

20 files changed

+374
-131
lines changed

20 files changed

+374
-131
lines changed

doc/overview.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Here is an end to end example of how to use a SageMaker Estimator:
116116
# Deletes the SageMaker model
117117
mxnet_predictor.delete_model()
118118
119-
The example above will eventually delete both the SageMaker endpoint and endpoint configuration through `delete_endpoint()`. If you want to keep your SageMaker endpoint configuration, use the value False for the `delete_endpoint_config` parameter, as shown below.
119+
The example above will eventually delete both the SageMaker endpoint and endpoint configuration through ``delete_endpoint()``. If you want to keep your SageMaker endpoint configuration, use the value ``False`` for the ``delete_endpoint_config`` parameter, as shown below.
120120

121121
.. code:: python
122122
@@ -180,10 +180,10 @@ Here is an example:
180180
train_input = algo.sagemaker_session.upload_data(path='/path/to/your/data')
181181
182182
algo.fit({'training': train_input})
183-
algo.deploy(1, 'ml.m4.xlarge')
183+
predictor = algo.deploy(1, 'ml.m4.xlarge')
184184
185185
# When you are done using your endpoint
186-
algo.delete_endpoint()
186+
predictor.delete_endpoint()
187187
188188
Use Scripts Stored in a Git Repository
189189
--------------------------------------
@@ -609,7 +609,7 @@ Here is a basic example of how to use it:
609609
response = my_predictor.predict(my_prediction_data)
610610
611611
# Tear down the SageMaker endpoint
612-
my_tuner.delete_endpoint()
612+
my_predictor.delete_endpoint()
613613
614614
This example shows a hyperparameter tuning job that creates up to 100 training jobs, running up to 10 training jobs at a time.
615615
Each training job's learning rate is a value between 0.05 and 0.06, but this value will differ between training jobs.

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,10 @@ class constructor
273273
init_params["image_name"] = image_name
274274
return init_params
275275

276-
training_job_name = init_params["base_job_name"]
277-
278276
if framework != cls.__framework_name__:
279277
raise ValueError(
280278
"Training job: {} didn't use image for requested framework".format(
281-
training_job_name
279+
job_details["TrainingJobName"]
282280
)
283281
)
284282
return init_params

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.cli.compatibility.v2 import modifiers
1919

2020
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.predictors.PredictorConstructorRefactor(),
2122
modifiers.framework_version.FrameworkVersionEnforcer(),
2223
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2324
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
@@ -28,7 +29,10 @@
2829

2930
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
3031

31-
IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()]
32+
IMPORT_FROM_MODIFIERS = [
33+
modifiers.predictors.PredictorImportFromRenamer(),
34+
modifiers.tfs.TensorFlowServingImportFromRenamer(),
35+
]
3236

3337

3438
class ASTTransformer(ast.NodeTransformer):

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
airflow,
1818
deprecated_params,
1919
framework_version,
20+
predictors,
2021
tf_legacy_mode,
2122
tfs,
2223
)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify Predictor code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import ast
19+
20+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
21+
22+
BASE_PREDICTOR = "RealTimePredictor"
23+
PREDICTORS = {
24+
"FactorizationMachinesPredictor": ("sagemaker", "sagemaker.amazon.factorization_machines"),
25+
"IPInsightsPredictor": ("sagemaker", "sagemaker.amazon.ipinsights"),
26+
"KMeansPredictor": ("sagemaker", "sagemaker.amazon.kmeans"),
27+
"KNNPredictor": ("sagemaker", "sagemaker.amazon.knn"),
28+
"LDAPredictor": ("sagemaker", "sagemaker.amazon.lda"),
29+
"LinearLearnerPredictor": ("sagemaker", "sagemaker.amazon.linear_learner"),
30+
"NTMPredictor": ("sagemaker", "sagemaker.amazon.ntm"),
31+
"PCAPredictor": ("sagemaker", "sagemaker.amazon.pca"),
32+
"RandomCutForestPredictor": ("sagemaker", "sagemaker.amazon.randomcutforest"),
33+
"RealTimePredictor": ("sagemaker", "sagemaker.predictor"),
34+
"SparkMLPredictor": ("sagemaker.sparkml", "sagemaker.sparkml.model"),
35+
}
36+
37+
38+
class PredictorConstructorRefactor(Modifier):
39+
"""A class to refactor *Predictor class and refactor endpoint attribute."""
40+
41+
def node_should_be_modified(self, node):
42+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
43+
44+
This looks for the following calls:
45+
46+
- ``sagemaker.<my>.<namespace>.<MyPredictor>``
47+
- ``sagemaker.<namespace>.<MyPredictor>``
48+
- ``<MyPredictor>``
49+
50+
Args:
51+
node (ast.Call): a node that represents a function call. For more,
52+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
53+
54+
Returns:
55+
bool: If the ``ast.Call`` instantiates a class of interest.
56+
"""
57+
return any(_matching(node, name, namespaces) for name, namespaces in PREDICTORS.items())
58+
59+
def modify_node(self, node):
60+
"""Modifies the ``ast.Call`` node to call ``Predictor`` instead.
61+
62+
Also renames ``endpoint`` attribute to ``endpoint_name``.
63+
64+
Args:
65+
node (ast.Call): a node that represents a *Predictor constructor.
66+
"""
67+
_rename_class(node)
68+
_rename_endpoint(node)
69+
70+
71+
def _matching(node, name, namespaces):
72+
"""Determines if the node matches the constructor name in the right namespace"""
73+
if _matching_name(node, name):
74+
return True
75+
76+
if not _matching_attr(node, name):
77+
return False
78+
79+
return any(_matching_namespace(node, namespace) for namespace in namespaces)
80+
81+
82+
def _matching_name(node, name):
83+
"""Determines if the node is an ast.Name node with a matching name"""
84+
return isinstance(node.func, ast.Name) and node.func.id == name
85+
86+
87+
def _matching_attr(node, name):
88+
"""Determines if the node is an ast.Attribute node with a matching name"""
89+
return isinstance(node.func, ast.Attribute) and node.func.attr == name
90+
91+
92+
def _matching_namespace(node, namespace):
93+
"""Determines if the node corresponds to a matching namespace"""
94+
names = namespace.split(".")
95+
name, value = names.pop(), node.func.value
96+
while isinstance(value, ast.Attribute) and len(names) > 0:
97+
if value.attr != name:
98+
return False
99+
name, value = names.pop(), value.value
100+
101+
return isinstance(value, ast.Name) and value.id == name
102+
103+
104+
def _rename_class(node):
105+
"""Renames the RealTimePredictor base class to Predictor"""
106+
if _matching_name(node, BASE_PREDICTOR):
107+
node.func.id = "Predictor"
108+
elif _matching_attr(node, BASE_PREDICTOR):
109+
node.func.attr = "Predictor"
110+
111+
112+
def _rename_endpoint(node):
113+
"""Renames keyword endpoint argument to endpoint_name"""
114+
for keyword in node.keywords:
115+
if keyword.arg == "endpoint":
116+
keyword.arg = "endpoint_name"
117+
break
118+
119+
120+
class PredictorImportFromRenamer(Modifier):
121+
"""A class to update import statements of ``RealTimePredictor``."""
122+
123+
def node_should_be_modified(self, node):
124+
"""Checks if the import statement imports ``RealTimePredictor`` from the correct module.
125+
126+
Args:
127+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
128+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
129+
130+
Returns:
131+
bool: If the import statement imports ``RealTimePredictor`` from the correct module.
132+
"""
133+
return node.module in PREDICTORS[BASE_PREDICTOR] and any(
134+
name.name == BASE_PREDICTOR for name in node.names
135+
)
136+
137+
def modify_node(self, node):
138+
"""Changes the ``ast.ImportFrom`` node's name from ``RealTimePredictor`` to ``Predictor``.
139+
140+
Args:
141+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
142+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
143+
"""
144+
for name in node.names:
145+
if name.name == BASE_PREDICTOR:
146+
name.name = "Predictor"

src/sagemaker/estimator.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from sagemaker.session import Session
5757
from sagemaker.session import s3_input
5858
from sagemaker.transformer import Transformer
59-
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
59+
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, get_config_value
6060
from sagemaker import vpc_utils
6161

6262

@@ -627,7 +627,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
627627

628628
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
629629
estimator.latest_training_job = _TrainingJob(
630-
sagemaker_session=sagemaker_session, job_name=init_params["base_job_name"]
630+
sagemaker_session=sagemaker_session, job_name=training_job_name
631631
)
632632
estimator._current_job_name = estimator.latest_training_job.name
633633
estimator.latest_training_job.wait()
@@ -791,7 +791,7 @@ class constructor
791791
init_params["train_volume_size"] = job_details["ResourceConfig"]["VolumeSizeInGB"]
792792
init_params["train_max_run"] = job_details["StoppingCondition"]["MaxRuntimeInSeconds"]
793793
init_params["input_mode"] = job_details["AlgorithmSpecification"]["TrainingInputMode"]
794-
init_params["base_job_name"] = job_details["TrainingJobName"]
794+
init_params["base_job_name"] = base_from_name(job_details["TrainingJobName"])
795795
init_params["output_path"] = job_details["OutputDataConfig"]["S3OutputPath"]
796796
init_params["output_kms_key"] = job_details["OutputDataConfig"]["KmsKeyId"]
797797
if "EnableNetworkIsolation" in job_details:
@@ -835,15 +835,6 @@ class constructor
835835

836836
return init_params
837837

838-
def delete_endpoint(self):
839-
"""Delete an Amazon SageMaker ``Endpoint``.
840-
841-
Raises:
842-
botocore.exceptions.ClientError: If the endpoint does not exist.
843-
"""
844-
self._ensure_latest_training_job(error_message="Endpoint was not created yet")
845-
self.sagemaker_session.delete_endpoint(self.latest_training_job.name)
846-
847838
def transformer(
848839
self,
849840
instance_count,

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,10 @@ class constructor
276276
init_params["image_name"] = image_name
277277
return init_params
278278

279-
training_job_name = init_params["base_job_name"]
280-
281279
if framework != cls.__framework_name__:
282280
raise ValueError(
283281
"Training job: {} didn't use image for requested framework".format(
284-
training_job_name
282+
job_details["TrainingJobName"]
285283
)
286284
)
287285

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,10 @@ class constructor
222222
init_params["image_name"] = image_name
223223
return init_params
224224

225-
training_job_name = init_params["base_job_name"]
226-
227225
if framework != cls.__framework_name__:
228226
raise ValueError(
229227
"Training job: {} didn't use image for requested framework".format(
230-
training_job_name
228+
job_details["TrainingJobName"]
231229
)
232230
)
233231

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,10 +315,9 @@ class constructor
315315
toolkit, toolkit_version = cls._toolkit_and_version_from_tag(tag)
316316

317317
if not cls._is_combination_supported(toolkit, toolkit_version, framework):
318-
training_job_name = init_params["base_job_name"]
319318
raise ValueError(
320319
"Training job: {} didn't use image for requested framework".format(
321-
training_job_name
320+
job_details["TrainingJobName"]
322321
)
323322
)
324323

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,10 @@ class constructor
244244
init_params["image_name"] = image_name
245245
return init_params
246246

247-
training_job_name = init_params["base_job_name"]
248-
249247
if framework and framework != cls.__framework_name__:
250248
raise ValueError(
251249
"Training job: {} didn't use image for requested framework".format(
252-
training_job_name
250+
job_details["TrainingJobName"]
253251
)
254252
)
255253

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
221221
if not script_mode:
222222
init_params["image_name"] = image_name
223223

224-
training_job_name = init_params["base_job_name"]
225224
if framework != cls.__framework_name__:
226225
raise ValueError(
227226
"Training job: {} didn't use image for requested framework".format(
228-
training_job_name
227+
job_details["TrainingJobName"]
229228
)
230229
)
231230

0 commit comments

Comments
 (0)