Skip to content

Commit 3b14a16

Browse files
author
Balaji Veeramani
committed
Update migration tool
1 parent c4bb695 commit 3b14a16

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
3838
modifiers.training_input.ShuffleConfigModuleRenamer(),
3939
modifiers.serde.SerdeConstructorRenamer(),
40+
modifiers.serde.SerdeKeywordRemover(),
4041
]
4142

4243
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/serde.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,48 @@ def modify_node(self, node):
157157
)
158158

159159

160+
class SerdeKeywordRemover(Modifier):
161+
"""A class to remove Serde-related keyword arguments from call expressions."""
162+
163+
def node_should_be_modified(self, node):
164+
"""Checks if the ``ast.Call`` node uses deprecated keywords.
165+
166+
In particular, this function checks if:
167+
168+
- The ``ast.Call`` represents the ``create_model`` method.
169+
- Either the serializer or deserializer keywords are used.
170+
171+
Args:
172+
node (ast.Call): a node that represents a function call. For more,
173+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
174+
175+
Returns:
176+
bool: If the ``ast.Call`` contains keywords that should be removed.
177+
"""
178+
if not isinstance(node.func, ast.Attribute) or node.func.attr != "create_model":
179+
return False
180+
return any(keyword.arg in {"serializer", "deserializer"} for keyword in node.keywords)
181+
182+
def modify_node(self, node):
183+
"""Removes the serializer and deserializer keywords, as applicable.
184+
185+
Args:
186+
node (ast.Call): a node that represents a ``create_model`` call.
187+
188+
Returns:
189+
ast.Call: the node that represents a ``create_model`` call without
190+
serializer or deserializers keywords.
191+
"""
192+
i = 0
193+
while i < len(node.keywords):
194+
keyword = node.keywords[i]
195+
if keyword.arg in {"serializer", "deserializer"}:
196+
node.keywords.pop(i)
197+
else:
198+
i += 1
199+
return node
200+
201+
160202
class SerdeObjectRenamer(Modifier):
161203
"""A class to rename SerDe objects imported from ``sagemaker.predictor``."""
162204

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,42 @@ def test_deserializer_module_modify_node(src, expected):
346346
node = pasta.parse(src)
347347
modified_node = modifier.modify_node(node)
348348
assert expected == pasta.dump(modified_node)
349+
350+
351+
@pytest.mark.parametrize(
352+
"src, expected",
353+
[
354+
('estimator.create_model(entry_point="inference.py")', False),
355+
("estimator.create_model(serializer=CSVSerializer())", True),
356+
("estimator.create_model(deserializer=CSVDeserializer())", True),
357+
(
358+
"estimator.create_model(serializer=CSVSerializer(), deserializer=CSVDeserializer())",
359+
True,
360+
),
361+
("estimator.deploy(serializer=CSVSerializer())", False),
362+
],
363+
)
364+
def test_create_model_call_node_should_be_modified(src, expected):
365+
modifier = serde.SerdeKeywordRemover()
366+
node = ast_call(src)
367+
assert modifier.node_should_be_modified(node) is expected
368+
369+
370+
@pytest.mark.parametrize(
371+
"src, expected",
372+
[
373+
(
374+
'estimator.create_model(entry_point="inference.py", serializer=CSVSerializer())',
375+
'estimator.create_model(entry_point="inference.py")',
376+
),
377+
(
378+
'estimator.create_model(entry_point="inference.py", deserializer=CSVDeserializer())',
379+
'estimator.create_model(entry_point="inference.py")',
380+
),
381+
],
382+
)
383+
def test_create_model_call_modify_node(src, expected):
384+
modifier = serde.SerdeKeywordRemover()
385+
node = ast_call(src)
386+
modified_node = modifier.modify_node(node)
387+
assert expected == pasta.dump(modified_node)

0 commit comments

Comments
 (0)