Skip to content

feature: Update migration tool to support breaking changes to create_model #1800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
modifiers.training_input.TrainingInputConstructorRefactor(),
modifiers.training_input.ShuffleConfigModuleRenamer(),
modifiers.serde.SerdeConstructorRenamer(),
modifiers.serde.SerdeKeywordRemover(),
modifiers.image_uris.ImageURIRetrieveRefactor(),
]

Expand Down
42 changes: 42 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,48 @@ def modify_node(self, node):
)


class SerdeKeywordRemover(Modifier):
"""A class to remove Serde-related keyword arguments from call expressions."""

def node_should_be_modified(self, node):
"""Checks if the ``ast.Call`` node uses deprecated keywords.

In particular, this function checks if:

- The ``ast.Call`` represents the ``create_model`` method.
- Either the serializer or deserializer keywords are used.

Args:
node (ast.Call): a node that represents a function call. For more,
see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
bool: If the ``ast.Call`` contains keywords that should be removed.
"""
if not isinstance(node.func, ast.Attribute) or node.func.attr != "create_model":
return False
Comment on lines +178 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional - create_model is pretty generic, so you may want to check for common estimator variable names to avoid unintended collisions (e.g. https://github.com/aws/sagemaker-python-sdk/blob/zwei/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py#L242)

however, serializer and deserializer are quite specific so maybe there's no need to worry here

return any(keyword.arg in {"serializer", "deserializer"} for keyword in node.keywords)

def modify_node(self, node):
"""Removes the serializer and deserializer keywords, as applicable.

Args:
node (ast.Call): a node that represents a ``create_model`` call.

Returns:
ast.Call: the node that represents a ``create_model`` call without
serializer or deserializers keywords.
"""
i = 0
while i < len(node.keywords):
keyword = node.keywords[i]
if keyword.arg in {"serializer", "deserializer"}:
node.keywords.pop(i)
else:
i += 1
return node


class SerdeObjectRenamer(Modifier):
"""A class to rename SerDe objects imported from ``sagemaker.predictor``."""

Expand Down
39 changes: 39 additions & 0 deletions tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,42 @@ def test_deserializer_module_modify_node(src, expected):
node = pasta.parse(src)
modified_node = modifier.modify_node(node)
assert expected == pasta.dump(modified_node)


@pytest.mark.parametrize(
"src, expected",
[
('estimator.create_model(entry_point="inference.py")', False),
("estimator.create_model(serializer=CSVSerializer())", True),
("estimator.create_model(deserializer=CSVDeserializer())", True),
(
"estimator.create_model(serializer=CSVSerializer(), deserializer=CSVDeserializer())",
True,
),
("estimator.deploy(serializer=CSVSerializer())", False),
],
)
def test_create_model_call_node_should_be_modified(src, expected):
modifier = serde.SerdeKeywordRemover()
node = ast_call(src)
assert modifier.node_should_be_modified(node) is expected


@pytest.mark.parametrize(
"src, expected",
[
(
'estimator.create_model(entry_point="inference.py", serializer=CSVSerializer())',
'estimator.create_model(entry_point="inference.py")',
),
(
'estimator.create_model(entry_point="inference.py", deserializer=CSVDeserializer())',
'estimator.create_model(entry_point="inference.py")',
),
],
)
def test_create_model_call_modify_node(src, expected):
modifier = serde.SerdeKeywordRemover()
node = ast_call(src)
modified_node = modifier.modify_node(node)
assert expected == pasta.dump(modified_node)