10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """Classes to modify Predictor code to be compatible
13
+ """Classes to modify TrainingInput code to be compatible
14
14
with version 2.0 and later of the SageMaker Python SDK.
15
15
"""
16
16
from __future__ import absolute_import
17
17
18
18
from sagemaker .cli .compatibility .v2 .modifiers import matching
19
19
from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
20
20
21
- BASE_S3_INPUT = "s3_input"
22
- SESSION = "session"
23
- S3_INPUT = {"s3_input" : ("sagemaker" , "sagemaker.session" )}
21
+ S3_INPUT_NAME = "s3_input"
22
+ S3_INPUT_NAMESPACES = ("sagemaker" , "sagemaker.inputs" , "sagemaker.session" )
24
23
25
24
26
25
class TrainingInputConstructorRefactor (Modifier ):
@@ -42,40 +41,35 @@ def node_should_be_modified(self, node):
42
41
Returns:
43
42
bool: If the ``ast.Call`` instantiates a class of interest.
44
43
"""
45
- return matching .matches_any (node , S3_INPUT )
44
+ return matching .matches_name_or_namespaces (node , S3_INPUT_NAME , S3_INPUT_NAMESPACES )
46
45
47
46
def modify_node (self , node ):
48
47
"""Modifies the ``ast.Call`` node to call ``TrainingInput`` instead.
49
48
50
49
Args:
51
50
node (ast.Call): a node that represents a *TrainingInput constructor.
52
51
"""
53
- _rename_class (node )
54
-
55
-
56
- def _rename_class (node ):
57
- """Renames the s3_input class to TrainingInput"""
58
- if matching .matches_name (node , BASE_S3_INPUT ):
59
- node .func .id = "TrainingInput"
60
- elif matching .matches_attr (node , BASE_S3_INPUT ):
61
- node .func .attr = "TrainingInput"
52
+ if matching .matches_name (node , S3_INPUT_NAME ):
53
+ node .func .id = "TrainingInput"
54
+ elif matching .matches_attr (node , S3_INPUT_NAME ):
55
+ node .func .attr = "TrainingInput"
62
56
63
57
64
58
class TrainingInputImportFromRenamer (Modifier ):
65
59
"""A class to update import statements of ``s3_input``."""
66
60
67
61
def node_should_be_modified (self , node ):
68
- """Checks if the import statement imports ``RealTimePredictor `` from the correct module.
62
+ """Checks if the import statement imports ``s3_input `` from the correct module.
69
63
70
64
Args:
71
65
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
72
66
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
73
67
74
68
Returns:
75
- bool: If the import statement imports ``RealTimePredictor `` from the correct module.
69
+ bool: If the import statement imports ``s3_input `` from the correct module.
76
70
"""
77
- return node .module in S3_INPUT [ BASE_S3_INPUT ] and any (
78
- name .name == BASE_S3_INPUT for name in node .names
71
+ return node .module in S3_INPUT_NAMESPACES and any (
72
+ name .name == S3_INPUT_NAME for name in node .names
79
73
)
80
74
81
75
def modify_node (self , node ):
@@ -86,7 +80,7 @@ def modify_node(self, node):
86
80
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
87
81
"""
88
82
for name in node .names :
89
- if name .name == BASE_S3_INPUT :
83
+ if name .name == S3_INPUT_NAME :
90
84
name .name = "TrainingInput"
91
85
elif name .name == "session" :
92
86
name .name = "inputs"
0 commit comments