Skip to content

Commit 4c08cd7

Browse files
author
Chuyang Deng
committed
change: add modifier for s3_input class
1 parent a6e4830 commit 4c08cd7

File tree

2 files changed

+180
-0
lines changed

2 files changed

+180
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify Predictor code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import matching
19+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
20+
21+
BASE_S3_INPUT = "s3_input"
22+
SESSION = "session"
23+
S3_INPUT = {"s3_input": ("sagemaker", "sagemaker.session")}
24+
25+
26+
class TrainingInputConstructorRefactor(Modifier):
27+
"""A class to refactor *s3_input class."""
28+
29+
def node_should_be_modified(self, node):
30+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
31+
32+
This looks for the following calls:
33+
34+
- ``sagemaker.s3_input``
35+
- ``sagemaker.session.s3_input``
36+
- ``s3_input``
37+
38+
Args:
39+
node (ast.Call): a node that represents a function call. For more,
40+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
41+
42+
Returns:
43+
bool: If the ``ast.Call`` instantiates a class of interest.
44+
"""
45+
return matching.matches_any(node, S3_INPUT)
46+
47+
def modify_node(self, node):
48+
"""Modifies the ``ast.Call`` node to call ``TrainingInput`` instead.
49+
50+
Args:
51+
node (ast.Call): a node that represents a *TrainingInput constructor.
52+
"""
53+
_rename_class(node)
54+
55+
56+
def _rename_class(node):
57+
"""Renames the s3_input class to TrainingInput"""
58+
if matching.matches_name(node, BASE_S3_INPUT):
59+
node.func.id = "TrainingInput"
60+
elif matching.matches_attr(node, BASE_S3_INPUT):
61+
node.func.attr = "TrainingInput"
62+
63+
64+
class TrainingInputImportFromRenamer(Modifier):
65+
"""A class to update import statements of ``s3_input``."""
66+
67+
def node_should_be_modified(self, node):
68+
"""Checks if the import statement imports ``RealTimePredictor`` from the correct module.
69+
70+
Args:
71+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
72+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
73+
74+
Returns:
75+
bool: If the import statement imports ``RealTimePredictor`` from the correct module.
76+
"""
77+
return node.module in S3_INPUT[BASE_S3_INPUT] and any(
78+
name.name == BASE_S3_INPUT for name in node.names
79+
)
80+
81+
def modify_node(self, node):
82+
"""Changes the ``ast.ImportFrom`` node's name from ``s3_input`` to ``TrainingInput``.
83+
84+
Args:
85+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
86+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
87+
"""
88+
for name in node.names:
89+
if name.name == BASE_S3_INPUT:
90+
name.name = "TrainingInput"
91+
elif name.name == "session":
92+
name.name = "inputs"
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
import pytest
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import training_input
19+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
20+
21+
22+
@pytest.fixture
23+
def constructors():
24+
return (
25+
"sagemaker.session.s3_input(s3_data='s3://a')",
26+
"sagemaker.s3_input(s3_data='s3://a')",
27+
"s3_input(s3_data='s3://a')",
28+
)
29+
30+
31+
@pytest.fixture
32+
def import_statements():
33+
return (
34+
"from sagemaker.session import s3_input",
35+
"from sagemaker import s3_input",
36+
)
37+
38+
39+
def test_constructor_node_should_be_modified(constructors):
40+
modifier = training_input.TrainingInputConstructorRefactor()
41+
for constructor in constructors:
42+
node = ast_call(constructor)
43+
assert modifier.node_should_be_modified(node)
44+
45+
46+
def test_constructor_node_should_be_modified_random_call():
47+
modifier = training_input.TrainingInputConstructorRefactor()
48+
node = ast_call("FileSystemInput()")
49+
assert not modifier.node_should_be_modified(node)
50+
51+
52+
def test_constructor_modify_node():
53+
modifier = training_input.TrainingInputConstructorRefactor()
54+
55+
node = ast_call("s3_input(s3_data='s3://a')")
56+
modifier.modify_node(node)
57+
assert "TrainingInput(s3_data='s3://a')" == pasta.dump(node)
58+
59+
node = ast_call("sagemaker.s3_input(s3_data='s3://a')")
60+
modifier.modify_node(node)
61+
assert "sagemaker.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
62+
63+
64+
def test_import_from_node_should_be_modified_training_input(import_statements):
65+
modifier = training_input.TrainingInputImportFromRenamer()
66+
for statement in import_statements:
67+
node = ast_import(statement)
68+
assert modifier.node_should_be_modified(node)
69+
70+
71+
def test_import_from_node_should_be_modified_random_import():
72+
modifier = training_input.TrainingInputImportFromRenamer()
73+
node = ast_import("from sagemaker import Session")
74+
assert not modifier.node_should_be_modified(node)
75+
76+
77+
def test_import_from_modify_node():
78+
modifier = training_input.TrainingInputImportFromRenamer()
79+
80+
node = ast_import("from sagemaker import s3_input")
81+
modifier.modify_node(node)
82+
expected_result = "from sagemaker import TrainingInput"
83+
assert expected_result == pasta.dump(node)
84+
85+
node = ast_import("from sagemaker.inputs import s3_input as training_input")
86+
modifier.modify_node(node)
87+
expected_result = "from sagemaker.inputs import TrainingInput as training_input"
88+
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)