Skip to content

Commit 2d69f65

Browse files
authored
Merge branch 'zwei' into add-serde-base-classes
2 parents 3bc8575 + 2d0d549 commit 2d69f65

File tree

6 files changed

+245
-79
lines changed

6 files changed

+245
-79
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
modifiers.tfs.TensorFlowServingConstructorRenamer(),
2626
modifiers.predictors.PredictorConstructorRefactor(),
2727
modifiers.airflow.ModelConfigArgModifier(),
28-
modifiers.estimators.DistributionParameterRenamer(),
28+
modifiers.renamed_params.DistributionParameterRenamer(),
29+
modifiers.renamed_params.S3SessionRenamer(),
2930
]
3031

3132
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
1717
airflow,
1818
deprecated_params,
19-
estimators,
2019
framework_version,
2120
predictors,
21+
renamed_params,
2222
tf_legacy_mode,
2323
tfs,
2424
)

src/sagemaker/cli/compatibility/v2/modifiers/estimators.py

Lines changed: 0 additions & 72 deletions
This file was deleted.
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify Predictor code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import ast
19+
from abc import abstractmethod
20+
21+
from sagemaker.cli.compatibility.v2.modifiers import matching
22+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
23+
24+
25+
class ParamRenamer(Modifier):
26+
"""Abstract class to take in an AST node, check if it is a function call with
27+
an argument that needs to be renamed, and rename the argument if needed.
28+
"""
29+
30+
@property
31+
@abstractmethod
32+
def calls_to_modify(self):
33+
"""A dictionary mapping function names to possible namespaces."""
34+
35+
@property
36+
@abstractmethod
37+
def old_param_name(self):
38+
"""The parameter name used in previous versions of the SageMaker Python SDK."""
39+
40+
@property
41+
@abstractmethod
42+
def new_param_name(self):
43+
"""The parameter name used in version 2.0 and later of the SageMaker Python SDK."""
44+
45+
def node_should_be_modified(self, node):
46+
"""Checks if the node matches any of the relevant functions and
47+
contains the parameter to be renamed.
48+
49+
Args:
50+
node (ast.Call): a node that represents a function call. For more,
51+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
52+
53+
Returns:
54+
bool: If the ``ast.Call`` matches the relevant function calls and
55+
contains the parameter to be renamed.
56+
"""
57+
return matching.matches_any(node, self.calls_to_modify) and self._has_param_to_rename(node)
58+
59+
def _has_param_to_rename(self, node):
60+
"""Checks if the call has the argument that needs to be renamed."""
61+
return _keyword_from_keywords(node, self.old_param_name) is not None
62+
63+
def modify_node(self, node):
64+
"""Modifies the ``ast.Call`` node to rename the attribute.
65+
66+
Args:
67+
node (ast.Call): a node that represents the relevant function call.
68+
"""
69+
keyword = _keyword_from_keywords(node, self.old_param_name)
70+
keyword.arg = self.new_param_name
71+
72+
73+
def _keyword_from_keywords(node, param_name):
74+
"""Retrieves a keyword argument from the node's keywords.
75+
76+
Args:
77+
node (ast.Call): a node that represents a function call. For more,
78+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
79+
param_name (str): the name of the argument.
80+
81+
Returns:
82+
ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``.
83+
"""
84+
for kw in node.keywords:
85+
if kw.arg == param_name:
86+
return kw
87+
88+
return None
89+
90+
91+
class DistributionParameterRenamer(ParamRenamer):
92+
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
93+
MXNet and TensorFlow estimators.
94+
95+
This looks for the following calls:
96+
97+
- ``<Framework>``
98+
- ``sagemaker.<framework>.<Framework>``
99+
- ``sagemaker.<framework>.estimator.<Framework>``
100+
101+
where ``<Framework>`` is either ``TensorFlow`` or ``MXNet``.
102+
"""
103+
104+
@property
105+
def calls_to_modify(self):
106+
"""A dictionary mapping ``MXNet`` and ``TensorFlow`` to their respective namespaces."""
107+
return {
108+
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
109+
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
110+
}
111+
112+
@property
113+
def old_param_name(self):
114+
"""The previous name for the distribution argument."""
115+
return "distributions"
116+
117+
@property
118+
def new_param_name(self):
119+
"""The new name for the distribution argument."""
120+
return "distribution"
121+
122+
123+
class S3SessionRenamer(ParamRenamer):
124+
"""A class to rename the ``session`` attribute to ``sagemaker_session`` in
125+
``S3Uploader`` and ``S3Downloader``.
126+
127+
This looks for the following calls:
128+
129+
- ``sagemaker.s3.S3Uploader.<function>``
130+
- ``s3.S3Uploader.<function>``
131+
- ``S3Uploader.<function>``
132+
133+
where ``S3Uploader`` is either ``S3Uploader`` or ``S3Downloader``, and where
134+
``<function>`` is any of the functions belonging to those two classes.
135+
"""
136+
137+
@property
138+
def calls_to_modify(self):
139+
"""A dictionary mapping S3 utility functions to their respective namespaces."""
140+
return {
141+
"download": ("sagemaker.s3.S3Downloader", "s3.S3Downloader", "S3Downloader"),
142+
"list": ("sagemaker.s3.S3Downloader", "s3.S3Downloader", "S3Downloader"),
143+
"read_file": ("sagemaker.s3.S3Downloader", "s3.S3Downloader", "S3Downloader"),
144+
"upload": ("sagemaker.s3.S3Uploader", "s3.S3Uploader", "S3Uploader"),
145+
"upload_string_as_file_body": (
146+
"sagemaker.s3.S3Uploader",
147+
"s3.S3Uploader",
148+
"S3Uploader",
149+
),
150+
}
151+
152+
@property
153+
def old_param_name(self):
154+
"""The previous name for the SageMaker session argument."""
155+
return "session"
156+
157+
@property
158+
def new_param_name(self):
159+
"""The new name for the SageMaker session argument."""
160+
return "sagemaker_session"
161+
162+
def node_should_be_modified(self, node):
163+
"""Checks if the node is one of the S3 utility functions and
164+
contains the ``session`` parameter.
165+
"""
166+
if isinstance(node.func, ast.Name):
167+
return False
168+
169+
return super(S3SessionRenamer, self).node_should_be_modified(node)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_distribution.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pasta
1616

17-
from sagemaker.cli.compatibility.v2.modifiers import estimators
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
1818
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
1919

2020

@@ -28,7 +28,7 @@ def test_node_should_be_modified():
2828
"sagemaker.mxnet.estimator.MXNet(distributions={})",
2929
)
3030

31-
modifier = estimators.DistributionParameterRenamer()
31+
modifier = renamed_params.DistributionParameterRenamer()
3232

3333
for call in constructors:
3434
assert modifier.node_should_be_modified(ast_call(call))
@@ -44,20 +44,20 @@ def test_node_should_be_modified_no_distribution():
4444
"sagemaker.mxnet.estimator.MXNet()",
4545
)
4646

47-
modifier = estimators.DistributionParameterRenamer()
47+
modifier = renamed_params.DistributionParameterRenamer()
4848

4949
for call in constructors:
5050
assert not modifier.node_should_be_modified(ast_call(call))
5151

5252

5353
def test_node_should_be_modified_random_function_call():
54-
modifier = estimators.DistributionParameterRenamer()
54+
modifier = renamed_params.DistributionParameterRenamer()
5555
assert not modifier.node_should_be_modified(ast_call("Session()"))
5656

5757

5858
def test_modify_node():
5959
node = ast_call("TensorFlow(distributions={'parameter_server': {'enabled': True}})")
60-
modifier = estimators.DistributionParameterRenamer()
60+
modifier = renamed_params.DistributionParameterRenamer()
6161
modifier.modify_node(node)
6262

6363
expected = "TensorFlow(distribution={'parameter_server': {'enabled': True}})"
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
NAMESPACES = ("", "s3.", "sagemaker.s3.")
21+
FUNCTIONS = (
22+
"S3Downloader.download",
23+
"S3Downloader.list",
24+
"S3Downloader.read_file",
25+
"S3Uploader.upload",
26+
"S3Uploader.upload_string_as_file_body",
27+
)
28+
29+
30+
def test_node_should_be_modified():
31+
modifier = renamed_params.S3SessionRenamer()
32+
33+
for func in FUNCTIONS:
34+
for namespace in NAMESPACES:
35+
call = ast_call("{}{}(session=sess)".format(namespace, func))
36+
assert modifier.node_should_be_modified(call)
37+
38+
39+
def test_node_should_be_modified_no_session():
40+
modifier = renamed_params.S3SessionRenamer()
41+
42+
for func in FUNCTIONS:
43+
for namespace in NAMESPACES:
44+
call = ast_call("{}{}()".format(namespace, func))
45+
assert not modifier.node_should_be_modified(call)
46+
47+
48+
def test_node_should_be_modified_random_function_call():
49+
modifier = renamed_params.S3SessionRenamer()
50+
51+
generic_function_calls = (
52+
"download()",
53+
"list()",
54+
"read_file()",
55+
"upload()",
56+
)
57+
58+
for call in generic_function_calls:
59+
assert not modifier.node_should_be_modified(ast_call(call))
60+
61+
62+
def test_modify_node():
63+
node = ast_call("S3Downloader.download(session=sess)")
64+
modifier = renamed_params.S3SessionRenamer()
65+
modifier.modify_node(node)
66+
67+
expected = "S3Downloader.download(sagemaker_session=sess)"
68+
assert expected == pasta.dump(node)

0 commit comments

Comments
 (0)