Skip to content

change: update migration tool for S3 utility functions #1665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
modifiers.tfs.TensorFlowServingConstructorRenamer(),
modifiers.predictors.PredictorConstructorRefactor(),
modifiers.airflow.ModelConfigArgModifier(),
modifiers.estimators.DistributionParameterRenamer(),
modifiers.renamed_params.DistributionParameterRenamer(),
modifiers.renamed_params.S3SessionRenamer(),
]

IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
airflow,
deprecated_params,
estimators,
framework_version,
predictors,
renamed_params,
tf_legacy_mode,
tfs,
)
72 changes: 0 additions & 72 deletions src/sagemaker/cli/compatibility/v2/modifiers/estimators.py

This file was deleted.

169 changes: 169 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Classes to modify Predictor code to be compatible
with version 2.0 and later of the SageMaker Python SDK.
"""
from __future__ import absolute_import

import ast
from abc import abstractmethod

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier


class ParamRenamer(Modifier):
    """Base modifier that renames a single keyword argument in matching calls.

    Subclasses describe the rename through three properties:
    ``calls_to_modify`` (which calls to look at), ``old_param_name`` (the
    keyword to replace) and ``new_param_name`` (its replacement).
    """

    @property
    @abstractmethod
    def calls_to_modify(self):
        """A dictionary mapping function names to possible namespaces."""

    @property
    @abstractmethod
    def old_param_name(self):
        """The parameter name used in previous versions of the SageMaker Python SDK."""

    @property
    @abstractmethod
    def new_param_name(self):
        """The parameter name used in version 2.0 and later of the SageMaker Python SDK."""

    def node_should_be_modified(self, node):
        """Decides whether the call needs its keyword argument renamed.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` matches one of ``calls_to_modify`` and
                passes the old parameter name as a keyword argument.
        """
        # The keyword lookup only runs when the call itself is a match.
        is_relevant_call = matching.matches_any(node, self.calls_to_modify)
        return is_relevant_call and self._has_param_to_rename(node)

    def _has_param_to_rename(self, node):
        """Checks if the call passes the old parameter name as a keyword argument."""
        old_keyword = _keyword_from_keywords(node, self.old_param_name)
        return old_keyword is not None

    def modify_node(self, node):
        """Renames the old keyword argument of the ``ast.Call`` node in place.

        Args:
            node (ast.Call): a node that represents the relevant function call.
                It is expected to contain the old keyword argument.
        """
        target = _keyword_from_keywords(node, self.old_param_name)
        target.arg = self.new_param_name


def _keyword_from_keywords(node, param_name):
    """Looks up a keyword argument of a function call by its name.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        param_name (str): the name of the argument.

    Returns:
        ast.keyword: the first keyword argument whose name equals ``param_name``,
            or ``None`` if the call has no such keyword.
    """
    matches = (candidate for candidate in node.keywords if candidate.arg == param_name)
    return next(matches, None)


class DistributionParameterRenamer(ParamRenamer):
    """Renames the ``distributions`` argument to ``distribution`` in
    MXNet and TensorFlow estimator calls.

    The following call forms are considered:

    - ``<Framework>``
    - ``sagemaker.<framework>.<Framework>``
    - ``sagemaker.<framework>.estimator.<Framework>``

    where ``<Framework>`` is either ``TensorFlow`` or ``MXNet``.
    """

    @property
    def calls_to_modify(self):
        """A dictionary mapping ``MXNet`` and ``TensorFlow`` to their respective namespaces."""
        tensorflow_namespaces = ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator")
        mxnet_namespaces = ("sagemaker.mxnet", "sagemaker.mxnet.estimator")
        return {"TensorFlow": tensorflow_namespaces, "MXNet": mxnet_namespaces}

    @property
    def old_param_name(self):
        """The name of the distribution argument before the rename."""
        return "distributions"

    @property
    def new_param_name(self):
        """The name of the distribution argument after the rename."""
        return "distribution"


class S3SessionRenamer(ParamRenamer):
    """Renames the ``session`` argument to ``sagemaker_session`` in calls to
    ``S3Uploader`` and ``S3Downloader`` functions.

    The following call forms are considered:

    - ``sagemaker.s3.<S3Class>.<function>``
    - ``s3.<S3Class>.<function>``
    - ``<S3Class>.<function>``

    where ``<S3Class>`` is either ``S3Uploader`` or ``S3Downloader``, and
    ``<function>`` is one of the function names listed in ``calls_to_modify``.
    """

    @property
    def calls_to_modify(self):
        """A dictionary mapping S3 utility functions to their respective namespaces."""
        downloader_namespaces = ("sagemaker.s3.S3Downloader", "s3.S3Downloader", "S3Downloader")
        uploader_namespaces = ("sagemaker.s3.S3Uploader", "s3.S3Uploader", "S3Uploader")
        return {
            "download": downloader_namespaces,
            "list": downloader_namespaces,
            "read_file": downloader_namespaces,
            "upload": uploader_namespaces,
            "upload_string_as_file_body": uploader_namespaces,
        }

    @property
    def old_param_name(self):
        """The name of the SageMaker session argument before the rename."""
        return "session"

    @property
    def new_param_name(self):
        """The name of the SageMaker session argument after the rename."""
        return "sagemaker_session"

    def node_should_be_modified(self, node):
        """Checks if the node is a call to one of the S3 utility functions
        that passes the ``session`` parameter.

        Calls whose function is a bare name (``ast.Name``), such as
        ``download(...)``, are always rejected; only attribute-style calls
        are handed to the generic check in the parent class.
        """
        if not isinstance(node.func, ast.Name):
            return super(S3SessionRenamer, self).node_should_be_modified(node)

        return False
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pasta

from sagemaker.cli.compatibility.v2.modifiers import estimators
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call


Expand All @@ -28,7 +28,7 @@ def test_node_should_be_modified():
"sagemaker.mxnet.estimator.MXNet(distributions={})",
)

modifier = estimators.DistributionParameterRenamer()
modifier = renamed_params.DistributionParameterRenamer()

for call in constructors:
assert modifier.node_should_be_modified(ast_call(call))
Expand All @@ -44,20 +44,20 @@ def test_node_should_be_modified_no_distribution():
"sagemaker.mxnet.estimator.MXNet()",
)

modifier = estimators.DistributionParameterRenamer()
modifier = renamed_params.DistributionParameterRenamer()

for call in constructors:
assert not modifier.node_should_be_modified(ast_call(call))


def test_node_should_be_modified_random_function_call():
modifier = estimators.DistributionParameterRenamer()
modifier = renamed_params.DistributionParameterRenamer()
assert not modifier.node_should_be_modified(ast_call("Session()"))


def test_modify_node():
node = ast_call("TensorFlow(distributions={'parameter_server': {'enabled': True}})")
modifier = estimators.DistributionParameterRenamer()
modifier = renamed_params.DistributionParameterRenamer()
modifier.modify_node(node)

expected = "TensorFlow(distribution={'parameter_server': {'enabled': True}})"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pasta

from sagemaker.cli.compatibility.v2.modifiers import renamed_params
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call

# Prefixes that the tests below prepend to each entry of FUNCTIONS to build the call source.
NAMESPACES = ("", "s3.", "sagemaker.s3.")
# ``<class>.<function>`` names that the tests below combine with every prefix in NAMESPACES.
FUNCTIONS = (
"S3Downloader.download",
"S3Downloader.list",
"S3Downloader.read_file",
"S3Uploader.upload",
"S3Uploader.upload_string_as_file_body",
)


def test_node_should_be_modified():
    """Every namespace/function combination called with ``session`` must be flagged."""
    renamer = renamed_params.S3SessionRenamer()
    sources = [
        "{}{}(session=sess)".format(prefix, function)
        for function in FUNCTIONS
        for prefix in NAMESPACES
    ]

    for source in sources:
        assert renamer.node_should_be_modified(ast_call(source))


def test_node_should_be_modified_no_session():
    """Calls to the S3 utility functions without ``session`` must not be flagged."""
    renamer = renamed_params.S3SessionRenamer()
    sources = [
        "{}{}()".format(prefix, function)
        for function in FUNCTIONS
        for prefix in NAMESPACES
    ]

    for source in sources:
        assert not renamer.node_should_be_modified(ast_call(source))


def test_node_should_be_modified_random_function_call():
    """Bare function calls that merely share a name with an S3 utility are ignored."""
    renamer = renamed_params.S3SessionRenamer()

    for source in ("download()", "list()", "read_file()", "upload()"):
        assert not renamer.node_should_be_modified(ast_call(source))


def test_modify_node():
    """``session`` is rewritten to ``sagemaker_session`` in the dumped source."""
    call = ast_call("S3Downloader.download(session=sess)")
    renamed_params.S3SessionRenamer().modify_node(call)

    assert "S3Downloader.download(sagemaker_session=sess)" == pasta.dump(call)