Skip to content

Commit c96e95a

Browse files
authored
Merge branch 'zwei' into ipynb-ignore-bangs
2 parents 34c98f7 + b1a2c23 commit c96e95a

File tree

11 files changed

+497
-62
lines changed

11 files changed

+497
-62
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
3838
modifiers.training_input.ShuffleConfigModuleRenamer(),
3939
modifiers.serde.SerdeConstructorRenamer(),
40+
modifiers.serde.SerdeKeywordRemover(),
41+
modifiers.image_uris.ImageURIRetrieveRefactor(),
4042
]
4143

4244
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
@@ -55,6 +57,7 @@
5557
modifiers.training_input.ShuffleConfigImportFromRenamer(),
5658
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
5759
modifiers.serde.SerdeImportFromPredictorRenamer(),
60+
modifiers.image_uris.ImageURIRetrieveImportFromRenamer(),
5861
]
5962

6063

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@
2424
tfs,
2525
training_params,
2626
training_input,
27+
image_uris,
2728
)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify image uri retrieve methods for Python SDK v2.0 and later."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import matching
19+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
20+
21+
GET_IMAGE_URI_NAME = "get_image_uri"
22+
GET_IMAGE_URI_NAMESPACES = (
23+
"sagemaker",
24+
"sagemaker.amazon_estimator",
25+
"sagemaker.amazon.amazon_estimator",
26+
"amazon_estimator",
27+
"amazon.amazon_estimator",
28+
)
29+
30+
31+
class ImageURIRetrieveRefactor(Modifier):
32+
"""A class to refactor *get_image_uri() method."""
33+
34+
def node_should_be_modified(self, node):
35+
"""Checks if the ``ast.Call`` node calls a function of interest.
36+
37+
This looks for the following calls:
38+
39+
- ``sagemaker.get_image_uri``
40+
- ``sagemaker.amazon_estimator.get_image_uri``
41+
- ``get_image_uri``
42+
43+
Args:
44+
node (ast.Call): a node that represents a function call. For more,
45+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
46+
47+
Returns:
48+
bool: If the ``ast.Call`` instantiates a class of interest.
49+
"""
50+
return matching.matches_name_or_namespaces(
51+
node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES
52+
)
53+
54+
def modify_node(self, node):
55+
"""Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead.
56+
And switch the first two parameters from (region, repo) to (framework, region)
57+
58+
Args:
59+
node (ast.Call): a node that represents a *image_uris.retrieve call.
60+
"""
61+
original_args = [None] * 3
62+
for kw in node.keywords:
63+
if kw.arg == "repo_name":
64+
original_args[0] = ast.Str(kw.value.s)
65+
elif kw.arg == "repo_region":
66+
original_args[1] = ast.Str(kw.value.s)
67+
elif kw.arg == "repo_version":
68+
original_args[2] = ast.Str(kw.value.s)
69+
70+
if len(node.args) > 0:
71+
original_args[1] = ast.Str(node.args[0].s)
72+
if len(node.args) > 1:
73+
original_args[0] = ast.Str(node.args[1].s)
74+
if len(node.args) > 2:
75+
original_args[2] = ast.Str(node.args[2].s)
76+
77+
args = []
78+
for arg in original_args:
79+
if arg:
80+
args.append(arg)
81+
82+
func = node.func
83+
has_sagemaker = False
84+
while hasattr(func, "value"):
85+
if hasattr(func.value, "id") and func.value.id == "sagemaker":
86+
has_sagemaker = True
87+
break
88+
func = func.value
89+
90+
if has_sagemaker:
91+
node.func = ast.Attribute(
92+
value=ast.Attribute(attr="image_uris", value=ast.Name(id="sagemaker")),
93+
attr="retrieve",
94+
)
95+
else:
96+
node.func = ast.Attribute(value=ast.Name(id="image_uris"), attr="retrieve")
97+
node.args = args
98+
node.keywords = []
99+
return node
100+
101+
102+
class ImageURIRetrieveImportFromRenamer(Modifier):
103+
"""A class to update import statements of ``get_image_uri``."""
104+
105+
def node_should_be_modified(self, node):
106+
"""Checks if the import statement imports ``get_image_uri`` from the correct module.
107+
108+
Args:
109+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
110+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
111+
112+
Returns:
113+
bool: If the import statement imports ``get_image_uri`` from the correct module.
114+
"""
115+
return node.module in GET_IMAGE_URI_NAMESPACES and any(
116+
name.name == GET_IMAGE_URI_NAME for name in node.names
117+
)
118+
119+
def modify_node(self, node):
120+
"""Changes the ``ast.ImportFrom`` node's name from ``get_image_uri`` to ``image_uris``.
121+
122+
Args:
123+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
124+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
125+
126+
Returns:
127+
ast.AST: the original node, which has been potentially modified.
128+
"""
129+
for name in node.names:
130+
if name.name == GET_IMAGE_URI_NAME:
131+
name.name = "image_uris"
132+
if node.module in GET_IMAGE_URI_NAMESPACES:
133+
node.module = "sagemaker"
134+
return node

src/sagemaker/cli/compatibility/v2/modifiers/serde.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,48 @@ def modify_node(self, node):
157157
)
158158

159159

160+
class SerdeKeywordRemover(Modifier):
161+
"""A class to remove Serde-related keyword arguments from call expressions."""
162+
163+
def node_should_be_modified(self, node):
164+
"""Checks if the ``ast.Call`` node uses deprecated keywords.
165+
166+
In particular, this function checks if:
167+
168+
- The ``ast.Call`` represents the ``create_model`` method.
169+
- Either the serializer or deserializer keywords are used.
170+
171+
Args:
172+
node (ast.Call): a node that represents a function call. For more,
173+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
174+
175+
Returns:
176+
bool: If the ``ast.Call`` contains keywords that should be removed.
177+
"""
178+
if not isinstance(node.func, ast.Attribute) or node.func.attr != "create_model":
179+
return False
180+
return any(keyword.arg in {"serializer", "deserializer"} for keyword in node.keywords)
181+
182+
def modify_node(self, node):
183+
"""Removes the serializer and deserializer keywords, as applicable.
184+
185+
Args:
186+
node (ast.Call): a node that represents a ``create_model`` call.
187+
188+
Returns:
189+
ast.Call: the node that represents a ``create_model`` call without
190+
serializer or deserializers keywords.
191+
"""
192+
i = 0
193+
while i < len(node.keywords):
194+
keyword = node.keywords[i]
195+
if keyword.arg in {"serializer", "deserializer"}:
196+
node.keywords.pop(i)
197+
else:
198+
i += 1
199+
return node
200+
201+
160202
class SerdeObjectRenamer(Modifier):
161203
"""A class to rename SerDe objects imported from ``sagemaker.predictor``."""
162204

0 commit comments

Comments
 (0)