Skip to content

Commit 060d716

Browse files
author
Chuyang Deng
committed
rename namespace
1 parent 99a9652 commit 060d716

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/training_input.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""
1616
from __future__ import absolute_import
1717

18+
import ast
19+
1820
from sagemaker.cli.compatibility.v2.modifiers import matching
1921
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2022

@@ -53,6 +55,15 @@ def modify_node(self, node):
5355
node.func.id = "TrainingInput"
5456
elif matching.matches_attr(node, S3_INPUT_NAME):
5557
node.func.attr = "TrainingInput"
58+
_rename_namespace(node, "session")
59+
60+
61+
def _rename_namespace(node, name):
62+
"""Rename namespace ``session`` to ``inputs`` """
63+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == name:
64+
node.func.value.attr = "inputs"
65+
elif isinstance(node.func.value, ast.Name) and node.func.value.id == name:
66+
node.func.value.id = "inputs"
5667

5768

5869
class TrainingInputImportFromRenamer(Modifier):
@@ -82,5 +93,5 @@ def modify_node(self, node):
8293
for name in node.names:
8394
if name.name == S3_INPUT_NAME:
8495
name.name = "TrainingInput"
85-
elif name.name == "session":
86-
name.name = "inputs"
96+
if node.module == "sagemaker.session":
97+
node.module = "sagemaker.inputs"

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_training_input.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
def constructors():
2424
return (
2525
"sagemaker.session.s3_input(s3_data='s3://a')",
26+
"sagemaker.inputs.s3_input(s3_data='s3://a')",
2627
"sagemaker.s3_input(s3_data='s3://a')",
28+
"session.s3_input(s3_data='s3://a')",
29+
"inputs.s3_input(s3_data='s3://a')",
2730
"s3_input(s3_data='s3://a')",
2831
)
2932

@@ -32,6 +35,7 @@ def constructors():
3235
def import_statements():
3336
return (
3437
"from sagemaker.session import s3_input",
38+
"from sagemaker.inputs import s3_input",
3539
"from sagemaker import s3_input",
3640
)
3741

@@ -60,6 +64,22 @@ def test_constructor_modify_node():
6064
modifier.modify_node(node)
6165
assert "sagemaker.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
6266

67+
node = ast_call("session.s3_input(s3_data='s3://a')")
68+
modifier.modify_node(node)
69+
assert "inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
70+
71+
node = ast_call("inputs.s3_input(s3_data='s3://a')")
72+
modifier.modify_node(node)
73+
assert "inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
74+
75+
node = ast_call("sagemaker.inputs.s3_input(s3_data='s3://a')")
76+
modifier.modify_node(node)
77+
assert "sagemaker.inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
78+
79+
node = ast_call("sagemaker.session.s3_input(s3_data='s3://a')")
80+
modifier.modify_node(node)
81+
assert "sagemaker.inputs.TrainingInput(s3_data='s3://a')" == pasta.dump(node)
82+
6383

6484
def test_import_from_node_should_be_modified_training_input(import_statements):
6585
modifier = training_input.TrainingInputImportFromRenamer()
@@ -70,7 +90,7 @@ def test_import_from_node_should_be_modified_training_input(import_statements):
7090

7191
def test_import_from_node_should_be_modified_random_import():
7292
modifier = training_input.TrainingInputImportFromRenamer()
73-
node = ast_import("from sagemaker import Session")
93+
node = ast_import("from sagemaker.session import Session")
7494
assert not modifier.node_should_be_modified(node)
7595

7696

@@ -89,5 +109,5 @@ def test_import_from_modify_node():
89109

90110
node = ast_import("from sagemaker.session import s3_input as training_input")
91111
modifier.modify_node(node)
92-
expected_result = "from sagemaker.session import TrainingInput as training_input"
112+
expected_result = "from sagemaker.inputs import TrainingInput as training_input"
93113
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)