Skip to content

Commit bbf6f2a

Browse files
author
Balaji Veeramani
committed
Address review comments
1 parent 6eab494 commit bbf6f2a

File tree

1 file changed

+11
-20
lines changed
  • src/sagemaker/cli/compatibility/v2/modifiers

1 file changed

+11
-20
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/serde.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from sagemaker.cli.compatibility.v2.modifiers import matching
2323
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2424

25+
OLD_AMAZON_CLASS_NAMES = {"numpy_to_record_serializer", "record_deserializer"}
26+
NEW_AMAZON_CLASS_NAMES = {"RecordSerializer", "RecordDeserializer"}
27+
2528
# The values are tuples so that the object can be passed to matching.matches_any.
2629
OLD_CLASS_NAME_TO_NAMESPACES = {
2730
"_CsvSerializer": ("sagemaker.predictor",),
@@ -33,9 +36,10 @@
3336
"StreamDeserializer": ("sagemaker.predictor",),
3437
"_NumpyDeserializer": ("sagemaker.predictor",),
3538
"_JsonDeserializer": ("sagemaker.predictor",),
36-
"numpy_to_record_serializer": ("sagemaker.amazon.common",),
37-
"record_deserializer": ("sagemaker.amazon.common",),
3839
}
40+
OLD_CLASS_NAMES_TO_NAMESPACES.update({
41+
class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES
42+
})
3943

4044
# The values are tuples so that the object can be passed to matching.matches_any.
4145
NEW_CLASS_NAME_TO_NAMESPACES = {
@@ -75,21 +79,6 @@
7579
"numpy_deserializer": "NumpyDeserializer",
7680
}
7781

78-
OLD_AMAZON_CLASS_NAMES = set(
79-
{
80-
class_name
81-
for class_name, namespaces in OLD_CLASS_NAME_TO_NAMESPACES.items()
82-
if "sagemaker.amazon.common" in namespaces
83-
}
84-
)
85-
NEW_AMAZON_CLASS_NAMES = set(
86-
{
87-
class_name
88-
for class_name, namespaces in NEW_CLASS_NAME_TO_NAMESPACES.items()
89-
if "sagemaker.amazon.common" in namespaces
90-
}
91-
)
92-
9382
NEW_CLASS_NAMES = set(OLD_CLASS_NAME_TO_NEW_CLASS_NAME.values())
9483
OLD_CLASS_NAMES = set(OLD_CLASS_NAME_TO_NEW_CLASS_NAME.keys())
9584

@@ -102,7 +91,7 @@ class SerdeConstructorRenamer(Modifier):
10291
def node_should_be_modified(self, node):
10392
"""Checks if the ``ast.Call`` node instantiates a SerDe class.
10493
105-
This looks for the following calls:
94+
This looks for the following calls (both with and without namespaces):
10695
10796
- ``sagemaker.predictor._CsvSerializer``
10897
- ``sagemaker.predictor._JsonSerializer``
@@ -126,7 +115,9 @@ def node_should_be_modified(self, node):
126115
return matching.matches_any(node, OLD_CLASS_NAME_TO_NAMESPACES)
127116

128117
def modify_node(self, node):
129-
"""Modifies the ``ast.Call`` node to use the classes for SerDe
118+
"""Updates the name and namespace of the ``ast.Call`` node, as applicable.
119+
120+
This method modifies the ``ast.Call`` node to use the SerDe classes
130121
available in version 2.0 and later of the Python SDK:
131122
132123
- ``sagemaker.serializers.CSVSerializer``
@@ -232,7 +223,7 @@ def node_should_be_modified(self, node):
232223
from the ``sagemaker.predictor`` module.
233224
"""
234225
return node.module == "sagemaker.predictor" and any(
235-
[name.name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES) for name in node.names]
226+
name.name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES) for name in node.names
236227
)
237228

238229
def modify_node(self, node):

0 commit comments

Comments
 (0)