22
22
from sagemaker .cli .compatibility .v2 .modifiers import matching
23
23
from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
24
24
25
+ OLD_AMAZON_CLASS_NAMES = {"numpy_to_record_serializer" , "record_deserializer" }
26
+ NEW_AMAZON_CLASS_NAMES = {"RecordSerializer" , "RecordDeserializer" }
27
+
25
28
# The values are tuples so that the object can be passed to matching.matches_any.
26
29
OLD_CLASS_NAME_TO_NAMESPACES = {
27
30
"_CsvSerializer" : ("sagemaker.predictor" ,),
33
36
"StreamDeserializer" : ("sagemaker.predictor" ,),
34
37
"_NumpyDeserializer" : ("sagemaker.predictor" ,),
35
38
"_JsonDeserializer" : ("sagemaker.predictor" ,),
36
- "numpy_to_record_serializer" : ("sagemaker.amazon.common" ,),
37
- "record_deserializer" : ("sagemaker.amazon.common" ,),
38
39
}
40
+ OLD_CLASS_NAMES_TO_NAMESPACES .update ({
41
+ class_name : ("sagemaker.amazon.common" ,) for class_name in OLD_AMAZON_CLASS_NAMES
42
+ })
39
43
40
44
# The values are tuples so that the object can be passed to matching.matches_any.
41
45
NEW_CLASS_NAME_TO_NAMESPACES = {
75
79
"numpy_deserializer" : "NumpyDeserializer" ,
76
80
}
77
81
78
- OLD_AMAZON_CLASS_NAMES = set (
79
- {
80
- class_name
81
- for class_name , namespaces in OLD_CLASS_NAME_TO_NAMESPACES .items ()
82
- if "sagemaker.amazon.common" in namespaces
83
- }
84
- )
85
- NEW_AMAZON_CLASS_NAMES = set (
86
- {
87
- class_name
88
- for class_name , namespaces in NEW_CLASS_NAME_TO_NAMESPACES .items ()
89
- if "sagemaker.amazon.common" in namespaces
90
- }
91
- )
92
-
93
82
NEW_CLASS_NAMES = set (OLD_CLASS_NAME_TO_NEW_CLASS_NAME .values ())
94
83
OLD_CLASS_NAMES = set (OLD_CLASS_NAME_TO_NEW_CLASS_NAME .keys ())
95
84
@@ -102,7 +91,7 @@ class SerdeConstructorRenamer(Modifier):
102
91
def node_should_be_modified (self , node ):
103
92
"""Checks if the ``ast.Call`` node instantiates a SerDe class.
104
93
105
- This looks for the following calls:
94
+ This looks for the following calls (both with and without namespaces) :
106
95
107
96
- ``sagemaker.predictor._CsvSerializer``
108
97
- ``sagemaker.predictor._JsonSerializer``
@@ -126,7 +115,9 @@ def node_should_be_modified(self, node):
126
115
return matching .matches_any (node , OLD_CLASS_NAME_TO_NAMESPACES )
127
116
128
117
def modify_node (self , node ):
129
- """Modifies the ``ast.Call`` node to use the classes for SerDe
118
+ """Updates the name and namespace of the ``ast.Call`` node, as applicable.
119
+
120
+ This method modifies the ``ast.Call`` node to use the SerDe classes
130
121
available in version 2.0 and later of the Python SDK:
131
122
132
123
- ``sagemaker.serializers.CSVSerializer``
@@ -232,7 +223,7 @@ def node_should_be_modified(self, node):
232
223
from the ``sagemaker.predictor`` module.
233
224
"""
234
225
return node .module == "sagemaker.predictor" and any (
235
- [ name .name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES ) for name in node .names ]
226
+ name .name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES ) for name in node .names
236
227
)
237
228
238
229
def modify_node (self , node ):
0 commit comments