Skip to content

Commit cfb7c2e

Browse files
author
Chuyang Deng
committed
modify namespaces
1 parent e1a0eed commit cfb7c2e

File tree

4 files changed

+21
-19
lines changed

4 files changed

+21
-19
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
modifiers.renamed_params.SessionCreateModelImageURIRenamer(),
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
37+
modifiers.training_input.TrainingInputConstructorRefactor(),
3738
]
3839

3940
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
4041

4142
IMPORT_FROM_MODIFIERS = [
4243
modifiers.predictors.PredictorImportFromRenamer(),
4344
modifiers.tfs.TensorFlowServingImportFromRenamer(),
45+
modifiers.training_input.TrainingInputImportFromRenamer(),
4446
]
4547

4648

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
tf_legacy_mode,
2323
tfs,
2424
training_params,
25+
training_input,
2526
)

src/sagemaker/cli/compatibility/v2/modifiers/training_input.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,16 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Classes to modify Predictor code to be compatible
13+
"""Classes to modify TrainingInput code to be compatible
1414
with version 2.0 and later of the SageMaker Python SDK.
1515
"""
1616
from __future__ import absolute_import
1717

1818
from sagemaker.cli.compatibility.v2.modifiers import matching
1919
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2020

21-
BASE_S3_INPUT = "s3_input"
22-
SESSION = "session"
23-
S3_INPUT = {"s3_input": ("sagemaker", "sagemaker.session")}
21+
S3_INPUT_NAME = "s3_input"
22+
S3_INPUT_NAMESPACES = ("sagemaker", "sagemaker.inputs", "sagemaker.session")
2423

2524

2625
class TrainingInputConstructorRefactor(Modifier):
@@ -42,40 +41,35 @@ def node_should_be_modified(self, node):
4241
Returns:
4342
bool: If the ``ast.Call`` instantiates a class of interest.
4443
"""
45-
return matching.matches_any(node, S3_INPUT)
44+
return matching.matches_name_or_namespaces(node, S3_INPUT_NAME, S3_INPUT_NAMESPACES)
4645

4746
def modify_node(self, node):
4847
"""Modifies the ``ast.Call`` node to call ``TrainingInput`` instead.
4948
5049
Args:
5150
node (ast.Call): a node that represents a *TrainingInput constructor.
5251
"""
53-
_rename_class(node)
54-
55-
56-
def _rename_class(node):
57-
"""Renames the s3_input class to TrainingInput"""
58-
if matching.matches_name(node, BASE_S3_INPUT):
59-
node.func.id = "TrainingInput"
60-
elif matching.matches_attr(node, BASE_S3_INPUT):
61-
node.func.attr = "TrainingInput"
52+
if matching.matches_name(node, S3_INPUT_NAME):
53+
node.func.id = "TrainingInput"
54+
elif matching.matches_attr(node, S3_INPUT_NAME):
55+
node.func.attr = "TrainingInput"
6256

6357

6458
class TrainingInputImportFromRenamer(Modifier):
6559
"""A class to update import statements of ``s3_input``."""
6660

6761
def node_should_be_modified(self, node):
68-
"""Checks if the import statement imports ``RealTimePredictor`` from the correct module.
62+
"""Checks if the import statement imports ``s3_input`` from the correct module.
6963
7064
Args:
7165
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
7266
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
7367
7468
Returns:
75-
bool: If the import statement imports ``RealTimePredictor`` from the correct module.
69+
bool: If the import statement imports ``s3_input`` from the correct module.
7670
"""
77-
return node.module in S3_INPUT[BASE_S3_INPUT] and any(
78-
name.name == BASE_S3_INPUT for name in node.names
71+
return node.module in S3_INPUT_NAMESPACES and any(
72+
name.name == S3_INPUT_NAME for name in node.names
7973
)
8074

8175
def modify_node(self, node):
@@ -86,7 +80,7 @@ def modify_node(self, node):
8680
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
8781
"""
8882
for name in node.names:
89-
if name.name == BASE_S3_INPUT:
83+
if name.name == S3_INPUT_NAME:
9084
name.name = "TrainingInput"
9185
elif name.name == "session":
9286
name.name = "inputs"

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_training_input.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,8 @@ def test_import_from_modify_node():
8686
modifier.modify_node(node)
8787
expected_result = "from sagemaker.inputs import TrainingInput as training_input"
8888
assert expected_result == pasta.dump(node)
89+
90+
node = ast_import("from sagemaker.session import s3_input as training_input")
91+
modifier.modify_node(node)
92+
expected_result = "from sagemaker.session import TrainingInput as training_input"
93+
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)