Skip to content

Commit 24b2ab9

Browse files
chuyang-dengChuyang Denglaurenyu
authored
change: add modifier for s3_input class (#1699)
* breaking: rename s3_input to TrainingInput * remove TrainingInput import from session * update docstring * change: add modifier for s3_input class * modify namespaces * rename namespace Co-authored-by: Chuyang Deng <[email protected]> Co-authored-by: Lauren Yu <[email protected]>
1 parent 0e4c0fa commit 24b2ab9

File tree

4 files changed

+213
-0
lines changed

4 files changed

+213
-0
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
modifiers.renamed_params.SessionCreateModelImageURIRenamer(),
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
37+
modifiers.training_input.TrainingInputConstructorRefactor(),
3738
]
3839

3940
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
4041

4142
IMPORT_FROM_MODIFIERS = [
4243
modifiers.predictors.PredictorImportFromRenamer(),
4344
modifiers.tfs.TensorFlowServingImportFromRenamer(),
45+
modifiers.training_input.TrainingInputImportFromRenamer(),
4446
]
4547

4648

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
tf_legacy_mode,
2323
tfs,
2424
training_params,
25+
training_input,
2526
)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify TrainingInput code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import ast
19+
20+
from sagemaker.cli.compatibility.v2.modifiers import matching
21+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
22+
23+
S3_INPUT_NAME = "s3_input"
24+
S3_INPUT_NAMESPACES = ("sagemaker", "sagemaker.inputs", "sagemaker.session")
25+
26+
27+
class TrainingInputConstructorRefactor(Modifier):
28+
"""A class to refactor *s3_input class."""
29+
30+
def node_should_be_modified(self, node):
31+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
32+
33+
This looks for the following calls:
34+
35+
- ``sagemaker.s3_input``
36+
- ``sagemaker.session.s3_input``
37+
- ``s3_input``
38+
39+
Args:
40+
node (ast.Call): a node that represents a function call. For more,
41+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
42+
43+
Returns:
44+
bool: If the ``ast.Call`` instantiates a class of interest.
45+
"""
46+
return matching.matches_name_or_namespaces(node, S3_INPUT_NAME, S3_INPUT_NAMESPACES)
47+
48+
def modify_node(self, node):
49+
"""Modifies the ``ast.Call`` node to call ``TrainingInput`` instead.
50+
51+
Args:
52+
node (ast.Call): a node that represents a *TrainingInput constructor.
53+
"""
54+
if matching.matches_name(node, S3_INPUT_NAME):
55+
node.func.id = "TrainingInput"
56+
elif matching.matches_attr(node, S3_INPUT_NAME):
57+
node.func.attr = "TrainingInput"
58+
_rename_namespace(node, "session")
59+
60+
61+
def _rename_namespace(node, name):
62+
"""Rename namespace ``session`` to ``inputs`` """
63+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == name:
64+
node.func.value.attr = "inputs"
65+
elif isinstance(node.func.value, ast.Name) and node.func.value.id == name:
66+
node.func.value.id = "inputs"
67+
68+
69+
class TrainingInputImportFromRenamer(Modifier):
70+
"""A class to update import statements of ``s3_input``."""
71+
72+
def node_should_be_modified(self, node):
73+
"""Checks if the import statement imports ``s3_input`` from the correct module.
74+
75+
Args:
76+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
77+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
78+
79+
Returns:
80+
bool: If the import statement imports ``s3_input`` from the correct module.
81+
"""
82+
return node.module in S3_INPUT_NAMESPACES and any(
83+
name.name == S3_INPUT_NAME for name in node.names
84+
)
85+
86+
def modify_node(self, node):
87+
"""Changes the ``ast.ImportFrom`` node's name from ``s3_input`` to ``TrainingInput``.
88+
89+
Args:
90+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
91+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
92+
"""
93+
for name in node.names:
94+
if name.name == S3_INPUT_NAME:
95+
name.name = "TrainingInput"
96+
if node.module == "sagemaker.session":
97+
node.module = "sagemaker.inputs"
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
import pytest
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import training_input
19+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
20+
21+
22+
@pytest.fixture
23+
def constructors():
24+
return (
25+
"sagemaker.session.s3_input(s3_data='s3://a')",
26+
"sagemaker.inputs.s3_input(s3_data='s3://a')",
27+
"sagemaker.s3_input(s3_data='s3://a')",
28+
"session.s3_input(s3_data='s3://a')",
29+
"inputs.s3_input(s3_data='s3://a')",
30+
"s3_input(s3_data='s3://a')",
31+
)
32+
33+
34+
@pytest.fixture
35+
def import_statements():
36+
return (
37+
"from sagemaker.session import s3_input",
38+
"from sagemaker.inputs import s3_input",
39+
"from sagemaker import s3_input",
40+
)
41+
42+
43+
def test_constructor_node_should_be_modified(constructors):
44+
modifier = training_input.TrainingInputConstructorRefactor()
45+
for constructor in constructors:
46+
node = ast_call(constructor)
47+
assert modifier.node_should_be_modified(node)
48+
49+
50+
def test_constructor_node_should_be_modified_random_call():
51+
modifier = training_input.TrainingInputConstructorRefactor()
52+
node = ast_call("FileSystemInput()")
53+
assert not modifier.node_should_be_modified(node)
54+
55+
56+
def test_constructor_modify_node():
57+
modifier = training_input.TrainingInputConstructorRefactor()
58+
59+
node = ast_call("s3_input(s3_data='s3://a')")
60+
modifier.modify_node(node)
61+
assert "TrainingInput(s3_data='s3://a')" == pasta.dump(node)
62+
63+
node = ast_call("sagemaker.s3_input(s3_data='s3://a')")
64+
modifier.modify_node(node)
65+
assert "sagemaker.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
66+
67+
node = ast_call("session.s3_input(s3_data='s3://a')")
68+
modifier.modify_node(node)
69+
assert "inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
70+
71+
node = ast_call("inputs.s3_input(s3_data='s3://a')")
72+
modifier.modify_node(node)
73+
assert "inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
74+
75+
node = ast_call("sagemaker.inputs.s3_input(s3_data='s3://a')")
76+
modifier.modify_node(node)
77+
assert "sagemaker.inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
78+
79+
node = ast_call("sagemaker.session.s3_input(s3_data='s3://a')")
80+
modifier.modify_node(node)
81+
assert "sagemaker.inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
82+
83+
84+
def test_import_from_node_should_be_modified_training_input(import_statements):
85+
modifier = training_input.TrainingInputImportFromRenamer()
86+
for statement in import_statements:
87+
node = ast_import(statement)
88+
assert modifier.node_should_be_modified(node)
89+
90+
91+
def test_import_from_node_should_be_modified_random_import():
92+
modifier = training_input.TrainingInputImportFromRenamer()
93+
node = ast_import("from sagemaker.session import Session")
94+
assert not modifier.node_should_be_modified(node)
95+
96+
97+
def test_import_from_modify_node():
98+
modifier = training_input.TrainingInputImportFromRenamer()
99+
100+
node = ast_import("from sagemaker import s3_input")
101+
modifier.modify_node(node)
102+
expected_result = "from sagemaker import TrainingInput"
103+
assert expected_result == pasta.dump(node)
104+
105+
node = ast_import("from sagemaker.inputs import s3_input as training_input")
106+
modifier.modify_node(node)
107+
expected_result = "from sagemaker.inputs import TrainingInput as training_input"
108+
assert expected_result == pasta.dump(node)
109+
110+
node = ast_import("from sagemaker.session import s3_input as training_input")
111+
modifier.modify_node(node)
112+
expected_result = "from sagemaker.inputs import TrainingInput as training_input"
113+
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)