23
23
def constructors ():
24
24
return (
25
25
"sagemaker.session.s3_input(s3_data='s3://a')" ,
26
+ "sagemaker.inputs.s3_input(s3_data='s3://a')" ,
26
27
"sagemaker.s3_input(s3_data='s3://a')" ,
28
+ "session.s3_input(s3_data='s3://a')" ,
29
+ "inputs.s3_input(s3_data='s3://a')" ,
27
30
"s3_input(s3_data='s3://a')" ,
28
31
)
29
32
@@ -32,6 +35,7 @@ def constructors():
32
35
def import_statements ():
33
36
return (
34
37
"from sagemaker.session import s3_input" ,
38
+ "from sagemaker.inputs import s3_input" ,
35
39
"from sagemaker import s3_input" ,
36
40
)
37
41
@@ -60,6 +64,22 @@ def test_constructor_modify_node():
60
64
modifier .modify_node (node )
61
65
assert "sagemaker.TrainingInput(s3_data='s3://a')" == pasta .dump (node )
62
66
67
+ node = ast_call ("session.s3_input(s3_data='s3://a')" )
68
+ modifier .modify_node (node )
69
+ assert "inputs.TrainingInput(s3_data='s3://a')" == pasta .dump (node )
70
+
71
+ node = ast_call ("inputs.s3_input(s3_data='s3://a')" )
72
+ modifier .modify_node (node )
73
+ assert "inputs.TrainingInput(s3_data='s3://a')" == pasta .dump (node )
74
+
75
+ node = ast_call ("sagemaker.inputs.s3_input(s3_data='s3://a')" )
76
+ modifier .modify_node (node )
77
+ assert "sagemaker.inputs.TrainingInput(s3_data='s3://a')" == pasta .dump (node )
78
+
79
+ node = ast_call ("sagemaker.session.s3_input(s3_data='s3://a')" )
80
+ modifier .modify_node (node )
81
+ assert "sagemaker.inputs.TrainingInput(s3_data='s3://a')" == pasta .dump (node )
82
+
63
83
64
84
def test_import_from_node_should_be_modified_training_input (import_statements ):
65
85
modifier = training_input .TrainingInputImportFromRenamer ()
@@ -70,7 +90,7 @@ def test_import_from_node_should_be_modified_training_input(import_statements):
70
90
71
91
def test_import_from_node_should_be_modified_random_import ():
72
92
modifier = training_input .TrainingInputImportFromRenamer ()
73
- node = ast_import ("from sagemaker import Session" )
93
+ node = ast_import ("from sagemaker.session import Session" )
74
94
assert not modifier .node_should_be_modified (node )
75
95
76
96
@@ -89,5 +109,5 @@ def test_import_from_modify_node():
89
109
90
110
node = ast_import ("from sagemaker.session import s3_input as training_input" )
91
111
modifier .modify_node (node )
92
- expected_result = "from sagemaker.session import TrainingInput as training_input"
112
+ expected_result = "from sagemaker.inputs import TrainingInput as training_input"
93
113
assert expected_result == pasta .dump (node )
0 commit comments