Skip to content

feature: add 1p algorithm image_uris migration tool #1792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
modifiers.training_input.TrainingInputConstructorRefactor(),
modifiers.training_input.ShuffleConfigModuleRenamer(),
modifiers.serde.SerdeConstructorRenamer(),
modifiers.image_uris.ImageURIRetrieveRefactor(),
]

IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
Expand All @@ -55,6 +56,7 @@
modifiers.training_input.ShuffleConfigImportFromRenamer(),
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
modifiers.serde.SerdeImportFromPredictorRenamer(),
modifiers.image_uris.ImageURIRetrieveImportFromRenamer(),
]


Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
tfs,
training_params,
training_input,
image_uris,
)
134 changes: 134 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Classes to modify image uri retrieve methods for Python SDK v2.0 and later."""
from __future__ import absolute_import

import ast

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

GET_IMAGE_URI_NAME = "get_image_uri"
GET_IMAGE_URI_NAMESPACES = (
"sagemaker",
"sagemaker.amazon_estimator",
"sagemaker.amazon.amazon_estimator",
"amazon_estimator",
"amazon.amazon_estimator",
)


class ImageURIRetrieveRefactor(Modifier):
"""A class to refactor *get_image_uri() method."""

def node_should_be_modified(self, node):
"""Checks if the ``ast.Call`` node calls a function of interest.

This looks for the following calls:

- ``sagemaker.get_image_uri``
- ``sagemaker.amazon_estimator.get_image_uri``
- ``get_image_uri``

Args:
node (ast.Call): a node that represents a function call. For more,
see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
bool: If the ``ast.Call`` instantiates a class of interest.
"""
return matching.matches_name_or_namespaces(
node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES
)

def modify_node(self, node):
"""Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead.
And switch the first two parameters from (region, repo) to (framework, region)

Args:
node (ast.Call): a node that represents a *image_uris.retrieve call.
"""
original_args = [None] * 3
for kw in node.keywords:
if kw.arg == "repo_name":
original_args[0] = ast.Str(kw.value.s)
elif kw.arg == "repo_region":
original_args[1] = ast.Str(kw.value.s)
elif kw.arg == "repo_version":
original_args[2] = ast.Str(kw.value.s)

if len(node.args) > 0:
original_args[1] = ast.Str(node.args[0].s)
if len(node.args) > 1:
original_args[0] = ast.Str(node.args[1].s)
if len(node.args) > 2:
original_args[2] = ast.Str(node.args[2].s)

args = []
for arg in original_args:
if arg:
args.append(arg)

func = node.func
has_sagemaker = False
while hasattr(func, "value"):
if hasattr(func.value, "id") and func.value.id == "sagemaker":
has_sagemaker = True
break
func = func.value

if has_sagemaker:
node.func = ast.Attribute(
value=ast.Attribute(attr="image_uris", value=ast.Name(id="sagemaker")),
attr="retrieve",
)
else:
node.func = ast.Attribute(value=ast.Name(id="image_uris"), attr="retrieve")
node.args = args
node.keywords = []
return node


class ImageURIRetrieveImportFromRenamer(Modifier):
"""A class to update import statements of ``get_image_uri``."""

def node_should_be_modified(self, node):
"""Checks if the import statement imports ``get_image_uri`` from the correct module.

Args:
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
bool: If the import statement imports ``get_image_uri`` from the correct module.
"""
return node.module in GET_IMAGE_URI_NAMESPACES and any(
name.name == GET_IMAGE_URI_NAME for name in node.names
)

def modify_node(self, node):
"""Changes the ``ast.ImportFrom`` node's name from ``get_image_uri`` to ``image_uris``.

Args:
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
for name in node.names:
if name.name == GET_IMAGE_URI_NAME:
name.name = "image_uris"
if node.module in GET_IMAGE_URI_NAMESPACES:
node.module = "sagemaker"
return node
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import image_uris
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import


@pytest.fixture
def methods():
return (
"get_image_uri('us-west-2', 'sagemaker-xgboost')",
"sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-xgboost')",
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-xgboost')",
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost', repo_version='1')",
)


@pytest.fixture
def import_statements():
return (
"from sagemaker import get_image_uri",
"from sagemaker.amazon_estimator import get_image_uri",
"from sagemaker.amazon.amazon_estimator import get_image_uri",
)


def test_method_node_should_be_modified(methods):
modifier = image_uris.ImageURIRetrieveRefactor()
for method in methods:
node = ast_call(method)
assert modifier.node_should_be_modified(node)


def test_methodnode_should_be_modified_random_call():
modifier = image_uris.ImageURIRetrieveRefactor()
node = ast_call("create_image_uri()")
assert not modifier.node_should_be_modified(node)


def test_method_modify_node(methods, caplog):
modifier = image_uris.ImageURIRetrieveRefactor()

method = "get_image_uri('us-west-2', 'xgboost')"
node = ast_call(method)
modifier.modify_node(node)
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)

method = "amazon_estimator.get_image_uri('us-west-2', 'xgboost')"
node = ast_call(method)
modifier.modify_node(node)
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)

method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='xgboost')"
node = ast_call(method)
modifier.modify_node(node)
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)

method = "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='xgboost')"
node = ast_call(method)
modifier.modify_node(node)
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)

method = (
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'xgboost', repo_version='1')"
)
node = ast_call(method)
modifier.modify_node(node)
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2', '1')" == pasta.dump(node)


def test_import_from_node_should_be_modified_image_uris_input(import_statements):
modifier = image_uris.ImageURIRetrieveImportFromRenamer()

statement = "from sagemaker import get_image_uri"
node = ast_import(statement)
assert modifier.node_should_be_modified(node)

statement = "from sagemaker.amazon_estimator import get_image_uri"
node = ast_import(statement)
assert modifier.node_should_be_modified(node)

statement = "from sagemaker.amazon.amazon_estimator import get_image_uri"
node = ast_import(statement)
assert modifier.node_should_be_modified(node)


def test_import_from_node_should_be_modified_random_import():
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
node = ast_import("from sagemaker.amazon_estimator import registry")
assert not modifier.node_should_be_modified(node)


def test_import_from_modify_node(import_statements):
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
expected_result = "from sagemaker import image_uris"

for import_statement in import_statements:
node = ast_import(import_statement)
modifier.modify_node(node)
assert expected_result == pasta.dump(node)