Skip to content

Commit a309b62

Browse files
author
Balaji Veeramani
committed
Appease review comments
1 parent 2d9ce08 commit a309b62

File tree

2 files changed

+33
-25
lines changed

2 files changed

+33
-25
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class ASTTransformer(ast.NodeTransformer):
6363

6464
def visit_Call(self, node):
6565
"""Visits an ``ast.Call`` node and returns a modified node or None.
66+
6667
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
6768
6869
Args:
@@ -79,38 +80,41 @@ def visit_Call(self, node):
7980

8081
def visit_Name(self, node):
8182
"""Visits an ``ast.Name`` node and returns a modified node or None.
83+
8284
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
8385
8486
Args:
8587
node (ast.Name): a node that represents an identifier.
8688
8789
Returns:
8890
ast.AST: if the returned node is None, the original node is removed
89-
from its location. Otherwise, the original node is replaced with the
90-
returned node.
91+
from its location. Otherwise, the original node is replaced with
92+
the returned node.
9193
"""
9294
for name_checker in NAME_MODIFIERS:
9395
node = name_checker.check_and_modify_node(node)
9496
return ast.fix_missing_locations(node) if node else None
9597

9698
def visit_Import(self, node):
9799
"""Visits an ``ast.Import`` node and returns a modified node or None.
100+
98101
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
99102
100103
Args:
101104
node (ast.Import): a node that represents an import statement.
102105
103106
Returns:
104107
ast.AST: if the returned node is None, the original node is removed
105-
from its location. Otherwise, the original node is replaced with the
106-
returned node.
108+
from its location. Otherwise, the original node is replaced with
109+
the returned node.
107110
"""
108111
for import_checker in IMPORT_MODIFIERS:
109112
node = import_checker.check_and_modify_node(node)
110113
return ast.fix_missing_locations(node) if node else None
111114

112115
def visit_Module(self, node):
113116
"""Visits an ``ast.Module`` node and returns a modified node or None.
117+
114118
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
115119
116120
The ``ast.NodeTransformer`` walks the abstract syntax tree and modifies
@@ -121,8 +125,8 @@ def visit_Module(self, node):
121125
122126
Returns:
123127
ast.AST: if the returned node is None, the original node is removed
124-
from its location. Otherwise, the original node is replaced with the
125-
returned node.
128+
from its location. Otherwise, the original node is replaced with
129+
the returned node.
126130
"""
127131
self.generic_visit(node)
128132
for module_checker in MODULE_MODIFIERS:
@@ -131,15 +135,16 @@ def visit_Module(self, node):
131135

132136
def visit_ImportFrom(self, node):
133137
"""Visits an ``ast.ImportFrom`` node and returns a modified node or None.
138+
134139
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
135140
136141
Args:
137142
node (ast.ImportFrom): a node that represents an import statement.
138143
139144
Returns:
140145
ast.AST: if the returned node is None, the original node is removed
141-
from its location. Otherwise, the original node is replaced with the
142-
returned node.
146+
from its location. Otherwise, the original node is replaced with
147+
the returned node.
143148
"""
144149
for import_checker in IMPORT_FROM_MODIFIERS:
145150
node = import_checker.check_and_modify_node(node)

src/sagemaker/cli/compatibility/v2/modifiers/serde.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,31 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Classes to modify serializer and deserializer code to be compatible with
14-
version 2.0 and later of the SageMaker Python SDK.
15-
"""
13+
"""Classes to modify SerDe code to be compatibile with version 2.0 and later."""
1614
from __future__ import absolute_import
1715

1816
import ast
1917

20-
import pasta
21-
2218
from sagemaker.cli.compatibility.v2.modifiers import matching
2319
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2420

2521
OLD_AMAZON_CLASS_NAMES = {"numpy_to_record_serializer", "record_deserializer"}
2622
NEW_AMAZON_CLASS_NAMES = {"RecordSerializer", "RecordDeserializer"}
23+
OLD_PREDICTOR_CLASS_NAMES = {
24+
"_CsvSerializer",
25+
"_JsonSerializer",
26+
"_NpySerializer",
27+
"_CsvDeserializer",
28+
"BytesDeserializer",
29+
"StringDeserializer",
30+
"StreamDeserializer",
31+
"_NumpyDeserializer",
32+
"_JsonDeserializer",
33+
}
2734

2835
# The values are tuples so that the object can be passed to matching.matches_any.
2936
OLD_CLASS_NAME_TO_NAMESPACES = {
30-
"_CsvSerializer": ("sagemaker.predictor",),
31-
"_JsonSerializer": ("sagemaker.predictor",),
32-
"_NpySerializer": ("sagemaker.predictor",),
33-
"_CsvDeserializer": ("sagemaker.predictor",),
34-
"BytesDeserializer": ("sagemaker.predictor",),
35-
"StringDeserializer": ("sagemaker.predictor",),
36-
"StreamDeserializer": ("sagemaker.predictor",),
37-
"_NumpyDeserializer": ("sagemaker.predictor",),
38-
"_JsonDeserializer": ("sagemaker.predictor",),
37+
class_name: ("sagemaker.predictor",) for class_name in OLD_PREDICTOR_CLASS_NAMES
3938
}
4039
OLD_CLASS_NAME_TO_NAMESPACES.update(
4140
{class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES}
@@ -205,7 +204,7 @@ def modify_node(self, node):
205204
object_name = node.id if isinstance(node, ast.Name) else node.attr
206205
new_class_name = OLD_OBJECT_NAME_TO_NEW_CLASS_NAME[object_name]
207206
namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0]
208-
subpackage_name = namespace_name[namespace_name.find(".") + 1 :]
207+
subpackage_name = namespace_name.split(".")[1]
209208
return ast.Call(
210209
func=ast.Attribute(value=ast.Name(id=subpackage_name), attr=new_class_name),
211210
args=[],
@@ -375,7 +374,9 @@ def __init__(self):
375374
for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES
376375
if "Serializer" in class_name
377376
}
378-
import_node = pasta.parse("from sagemaker import serializers\n").body[0]
377+
import_node = ast.ImportFrom(
378+
module="sagemaker", names=[ast.alias(name="serializers", asname=None)], level=0
379+
)
379380
super().__init__(class_names, import_node)
380381

381382

@@ -403,5 +404,7 @@ def __init__(self):
403404
for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES
404405
if "Deserializer" in class_name
405406
}
406-
import_node = pasta.parse("from sagemaker import deserializers\n").body[0]
407+
import_node = ast.ImportFrom(
408+
module="sagemaker", names=[ast.alias(name="deserializers", asname=None)], level=0
409+
)
407410
super().__init__(class_names, import_node)

0 commit comments

Comments
 (0)