Skip to content

change: handle image_uri rename for estimators and models in v2 migration tool #1675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from sagemaker.cli.compatibility.v2 import modifiers

FUNCTION_CALL_MODIFIERS = [
modifiers.renamed_params.EstimatorImageURIRenamer(),
modifiers.renamed_params.ModelImageURIRenamer(),
modifiers.framework_version.FrameworkVersionEnforcer(),
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

FRAMEWORK_ARG = "framework_version"
IMAGE_ARG = "image_uri"
PY_ARG = "py_version"

FRAMEWORK_DEFAULTS = {
Expand Down Expand Up @@ -70,11 +71,8 @@ def node_should_be_modified(self, node):
bool: If the ``ast.Call`` is instantiating a framework class that
should specify ``framework_version``, but doesn't.
"""
if matching.matches_any(node, ESTIMATORS):
return _version_args_needed(node, "image_name")

if matching.matches_any(node, MODELS):
return _version_args_needed(node, "image")
if matching.matches_any(node, ESTIMATORS) or matching.matches_any(node, MODELS):
return _version_args_needed(node)

return False

Expand Down Expand Up @@ -169,13 +167,13 @@ def _framework_from_node(node):
return framework, is_model


def _version_args_needed(node, image_arg):
def _version_args_needed(node):
"""Determines if image_arg or version_arg was supplied

Applies similar logic as ``validate_version_or_image_args``
"""
# if image_arg is present, no need to supply version arguments
if matching.has_arg(node, image_arg):
if matching.has_arg(node, IMAGE_ARG):
return False

# if framework_version is None, need args
Expand Down
62 changes: 62 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,65 @@ def node_should_be_modified(self, node):
return False

return super(S3SessionRenamer, self).node_should_be_modified(node)


class EstimatorImageURIRenamer(ParamRenamer):
"""A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""

@property
def calls_to_modify(self):
"""A dictionary mapping estimators with the ``image_name`` attribute to their
respective namespaces.
Comment on lines +157 to +158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ignorable nit pep-0257

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I blame our 100-character line limit 😂

"""
return {
"Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
"Estimator": ("sagemaker.estimator",),
"Framework": ("sagemaker.estimator",),
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
"PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
"RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
"SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
"XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
}

@property
def old_param_name(self):
"""The previous name for the image URI argument."""
return "image_name"

@property
def new_param_name(self):
"""The new name for the image URI argument."""
return "image_uri"


class ModelImageURIRenamer(ParamRenamer):
"""A class to rename the ``image`` attribute to ``image_uri`` in models."""

@property
def calls_to_modify(self):
"""A dictionary mapping models with the ``image`` attribute to their
respective namespaces.
"""
return {
"ChainerModel": ("sagemaker.chainer", "sagemaker.chainer.model"),
"Model": ("sagemaker.model",),
"MultiDataModel": ("sagemaker.multidatamodel",),
"FrameworkModel": ("sagemaker.model",),
"MXNetModel": ("sagemaker.mxnet", "sagemaker.mxnet.model"),
"PyTorchModel": ("sagemaker.pytorch", "sagemaker.pytorch.model"),
"SKLearnModel": ("sagemaker.sklearn", "sagemaker.sklearn.model"),
"TensorFlowModel": ("sagemaker.tensorflow", "sagemaker.tensorflow.model"),
"XGBoostModel": ("sagemaker.xgboost", "sagemaker.xgboost.model"),
}

@property
def old_param_name(self):
"""The previous name for the image URI argument."""
return "image"

@property
def new_param_name(self):
"""The new name for the image URI argument."""
return "image_uri"
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ def modify_node(self, node):
hp_key = self._hyperparameter_key_for_param(kw.arg)
additional_hps[hp_key] = kw.value
kw_to_remove.append(kw)
if kw.arg == "image_name":
if kw.arg == "image_uri":
add_image_uri = False

self._remove_keywords(node, kw_to_remove)
self._add_updated_hyperparameters(node, base_hps, additional_hps)

if add_image_uri:
image_uri = self._image_uri_from_args(node.keywords)
node.keywords.append(ast.keyword(arg="image_name", value=ast.Str(s=image_uri)))
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))

node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _templates(self, model=False):
def _frameworks(self, versions=False, image=False):
keywords = dict()
if image:
keywords["image_name"] = "my:image"
keywords["image_uri"] = "my:image"
if versions:
keywords["framework_version"] = self.framework_version
keywords["py_version"] = self.py_version
Expand All @@ -66,7 +66,7 @@ def _frameworks(self, versions=False, image=False):
def _models(self, versions=False, image=False):
keywords = dict()
if image:
keywords["image"] = "my:image"
keywords["image_uri"] = "my:image"
if versions:
keywords["framework_version"] = self.framework_version
if self.py_version_for_model:
Expand Down
118 changes: 118 additions & 0 deletions tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_image_uri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pasta

from sagemaker.cli.compatibility.v2.modifiers import renamed_params
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call

ESTIMATORS = {
"Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
"Estimator": ("sagemaker.estimator",),
"Framework": ("sagemaker.estimator",),
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
"PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
"RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
"SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
"XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
}

MODELS = {
"ChainerModel": ("sagemaker.chainer", "sagemaker.chainer.model"),
"Model": ("sagemaker.model",),
"MultiDataModel": ("sagemaker.multidatamodel",),
"FrameworkModel": ("sagemaker.model",),
"MXNetModel": ("sagemaker.mxnet", "sagemaker.mxnet.model"),
"PyTorchModel": ("sagemaker.pytorch", "sagemaker.pytorch.model"),
"SKLearnModel": ("sagemaker.sklearn", "sagemaker.sklearn.model"),
"TensorFlowModel": ("sagemaker.tensorflow", "sagemaker.tensorflow.model"),
"XGBoostModel": ("sagemaker.xgboost", "sagemaker.xgboost.model"),
}


def test_estimator_node_should_be_modified():
modifier = renamed_params.EstimatorImageURIRenamer()

for estimator, namespaces in ESTIMATORS.items():
call = "{}(image_name='my-image:latest')".format(estimator)
assert modifier.node_should_be_modified(ast_call(call))

for namespace in namespaces:
call = "{}.{}(image_name='my-image:latest')".format(namespace, estimator)
assert modifier.node_should_be_modified(ast_call(call))


def test_estimator_node_should_be_modified_no_distribution():
modifier = renamed_params.EstimatorImageURIRenamer()

for estimator, namespaces in ESTIMATORS.items():
call = "{}()".format(estimator)
assert not modifier.node_should_be_modified(ast_call(call))

for namespace in namespaces:
call = "{}.{}()".format(namespace, estimator)
assert not modifier.node_should_be_modified(ast_call(call))


def test_estimator_node_should_be_modified_random_function_call():
modifier = renamed_params.EstimatorImageURIRenamer()
assert not modifier.node_should_be_modified(ast_call("Session()"))


def test_estimator_modify_node():
node = ast_call("TensorFlow(image_name=my_image)")
modifier = renamed_params.EstimatorImageURIRenamer()
modifier.modify_node(node)

expected = "TensorFlow(image_uri=my_image)"
assert expected == pasta.dump(node)


def test_model_node_should_be_modified():
modifier = renamed_params.ModelImageURIRenamer()

for model, namespaces in MODELS.items():
call = "{}(image='my-image:latest')".format(model)
assert modifier.node_should_be_modified(ast_call(call))

for namespace in namespaces:
call = "{}.{}(image='my-image:latest')".format(namespace, model)
assert modifier.node_should_be_modified(ast_call(call))


def test_model_node_should_be_modified_no_distribution():
modifier = renamed_params.ModelImageURIRenamer()

for model, namespaces in MODELS.items():
call = "{}()".format(model)
assert not modifier.node_should_be_modified(ast_call(call))

for namespace in namespaces:
call = "{}.{}()".format(namespace, model)
assert not modifier.node_should_be_modified(ast_call(call))


def test_model_node_should_be_modified_random_function_call():
modifier = renamed_params.ModelImageURIRenamer()
assert not modifier.node_should_be_modified(ast_call("Session()"))


def test_model_modify_node():
node = ast_call("TensorFlowModel(image=my_image)")
modifier = renamed_params.ModelImageURIRenamer()
modifier.modify_node(node)

expected = "TensorFlowModel(image_uri=my_image)"
assert expected == pasta.dump(node)
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session
node = ast_call(constructor)
modifier.modify_node(node)

assert "TensorFlow(image_name='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
assert "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
create_image_uri.assert_called_with(
REGION_NAME, "tensorflow", "ml.m4.xlarge", "1.11.0", "py2"
)
Expand All @@ -111,7 +111,7 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):

expected_string = (
"TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
"image_name='{}', model_dir=False)".format(IMAGE_URI)
"image_uri='{}', model_dir=False)".format(IMAGE_URI)
)
assert expected_string == pasta.dump(node)

Expand Down