Skip to content

fix: handle named variables in v2 migration tool #1702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import ast

from packaging.version import InvalidVersion, Version
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1


from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

Expand Down Expand Up @@ -135,10 +137,15 @@ def _tf_py_version_default(framework_version):
"""Gets the py_version default based on framework_version for TensorFlow."""
if not framework_version:
return "py2"
version = [int(s) for s in framework_version.split(".")]
if version < [1, 12]:

try:
version = Version(framework_version)
except InvalidVersion:
return "py2"

if version < Version("1.12"):
return "py2"
if version < [2, 2]:
if version < Version("2.2"):
return "py3"
return "py37"

Expand Down Expand Up @@ -186,7 +193,6 @@ def _version_args_needed(node):
framework, is_model = _framework_from_node(node)
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
if expecting_py_version:
py_version = parsing.arg_value(node, PY_ARG)
return py_version is None
return not matching.has_arg(node, PY_ARG)

return False
24 changes: 15 additions & 9 deletions src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def _is_legacy_mode(self, node):

for kw in node.keywords:
if kw.arg == "script_mode":
script_mode = bool(kw.value.value)
script_mode = (
bool(kw.value.value) if isinstance(kw.value, ast.NameConstant) else True
)
if kw.arg == "py_version":
py_version = kw.value.s
py_version = kw.value.s if isinstance(kw.value, ast.Str) else "py3"

return not (py_version.startswith("py3") or script_mode)

Expand Down Expand Up @@ -124,7 +126,8 @@ def modify_node(self, node):

if add_image_uri:
image_uri = self._image_uri_from_args(node.keywords)
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
if image_uri:
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))

node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))

Expand Down Expand Up @@ -155,19 +158,22 @@ def _to_ast_keyword(self, hps):
return None

def _image_uri_from_args(self, keywords):
"""Returns a legacy TensorFlow image URI based on the estimator arguments."""
"""Returns a legacy TensorFlow image URI based on the estimator arguments if possible."""
tf_version = framework_version.FRAMEWORK_DEFAULTS["TensorFlow"]
instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)

for kw in keywords:
if kw.arg == "framework_version":
tf_version = kw.value.s
tf_version = kw.value.s if isinstance(kw.value, ast.Str) else None
if kw.arg == "train_instance_type":
instance_type = kw.value.s
instance_type = kw.value.s if isinstance(kw.value, ast.Str) else None

return fw_utils.create_image_uri(
self.region, "tensorflow", instance_type, tf_version, "py2"
)
if tf_version and instance_type:
return fw_utils.create_image_uri(
self.region, "tensorflow", instance_type, tf_version, "py2"
)

return None


class TensorBoardParameterRemover(Modifier):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import framework_version
Expand All @@ -36,8 +35,10 @@ def __init__(
self.py_version = py_version
self.py_version_for_model = py_version_for_model

def constructors(self, versions=False, image=False):
return self._frameworks(versions, image) + self._models(versions, image)
def constructors(self, fw_version=False, py_version=False, image=False):
return self._frameworks(fw_version, py_version, image) + self._models(
fw_version, py_version, image
)

def _templates(self, model=False):
module = self.framework.lower()
Expand All @@ -54,30 +55,38 @@ def _templates(self, model=False):
for template in templates
)

def _frameworks(self, versions=False, image=False):
keywords = dict()
if image:
keywords["image_uri"] = "my:image"
if versions:
keywords["framework_version"] = self.framework_version
keywords["py_version"] = self.py_version
def _frameworks(self, fw_version=False, py_version=False, image=False):
keywords = self._base_keywords(fw_version, image)
if py_version:
keywords["py_version"] = (
"py_version" if py_version == "named" else "'{}'".format(self.py_version)
)
return _format_templates(keywords, self._templates())

def _models(self, versions=False, image=False):
def _models(self, fw_version=False, py_version=False, image=False):
keywords = self._base_keywords(fw_version, image)
if py_version and self.py_version_for_model:
keywords["py_version"] = (
"py_version" if py_version == "named" else "'{}'".format(self.py_version)
)
return _format_templates(keywords, self._templates(model=True))

def _base_keywords(self, fw_version=False, image=False):
keywords = dict()
if image:
keywords["image_uri"] = "my:image"
if versions:
keywords["framework_version"] = self.framework_version
if self.py_version_for_model:
keywords["py_version"] = self.py_version
return _format_templates(keywords, self._templates(model=True))
keywords["image_uri"] = "'my:image'"
if fw_version:
keywords["framework_version"] = (
"fw_version" if fw_version == "named" else "'{}'".format(self.framework_version)
)
return keywords


def _format_templates(keywords, templates):
args = ", ".join(
"{key}='{value}'".format(key=key, value=value) for key, value in keywords.items()
"{key}={value}".format(key=key, value=value) for key, value in keywords.items()
)

return [template.format(args) for template in templates]


Expand All @@ -100,8 +109,12 @@ def _format_templates(keywords, templates):
]


def constructors(versions=False, image=False):
return [ctr for template in TEMPLATES for ctr in template.constructors(versions, image)]
def constructors(fw_version=False, py_version=False, image=False):
return [
ctr
for template in TEMPLATES
for ctr in template.constructors(fw_version, py_version, image)
]


@pytest.fixture
Expand All @@ -110,18 +123,34 @@ def constructors_empty():


@pytest.fixture
def constructors_with_versions():
return constructors(versions=True)
def constructors_with_only_fw_version_that_need_py_version():
ctrs = []
for template in TEMPLATES:
if template.py_version_for_model:
ctrs.extend(template.constructors(fw_version=True))
else:
ctrs.extend(template._frameworks(fw_version=True))
return ctrs


@pytest.fixture
def constructors_with_image():
return constructors(image=True)
def constructors_with_only_fw_version():
return constructors(fw_version=True)


@pytest.fixture
def constructors_with_only_py_version():
return constructors(py_version=True)


@pytest.fixture
def constructors_with_both():
return constructors(versions=True, image=True)
def constructors_with_both_versions():
return constructors(fw_version=True, py_version=True)


@pytest.fixture
def constructors_with_image():
return constructors(image=True)


def _test_node_should_be_modified(ctrs, should_modify=True):
Expand All @@ -138,8 +167,20 @@ def test_node_should_be_modified_empty(constructors_empty):
_test_node_should_be_modified(constructors_empty, should_modify=True)


def test_node_should_be_modified_with_versions(constructors_with_versions):
_test_node_should_be_modified(constructors_with_versions, should_modify=False)
def test_node_should_be_modified_with_only_fw_versions(
constructors_with_only_fw_version_that_need_py_version,
):
_test_node_should_be_modified(
constructors_with_only_fw_version_that_need_py_version, should_modify=True
)


def test_node_should_be_modified_with_only_py_versions(constructors_with_only_py_version):
_test_node_should_be_modified(constructors_with_only_py_version, should_modify=True)


def test_node_should_be_modified_with_versions(constructors_with_both_versions):
_test_node_should_be_modified(constructors_with_both_versions, should_modify=False)


def test_node_should_be_modified_with_image(constructors_with_image):
Expand All @@ -155,17 +196,40 @@ def _test_modify_node(ctrs_before, ctrs_expected):
for before, expected in zip(ctrs_before, ctrs_expected):
node = ast_call(before)
modifier.modify_node(node)
# NOTE: this type of equality with pasta depends on ordering of args...
assert expected == pasta.dump(node)
_assert_equal_kwargs(ast_call(expected), node)


def _assert_equal_kwargs(expected, actual):
assert _keywords_for_node(expected) == _keywords_for_node(actual)


def test_modify_node_empty(constructors_empty, constructors_with_versions):
_test_modify_node(constructors_empty, constructors_with_versions)
def _keywords_for_node(node):
return {kw.arg: getattr(kw.value, kw.value._fields[0]) for kw in node.keywords}


def test_modify_node_with_versions(constructors_with_versions):
_test_modify_node(constructors_with_versions, constructors_with_versions)
def test_modify_node_empty(constructors_empty, constructors_with_both_versions):
_test_modify_node(constructors_empty, constructors_with_both_versions)


def test_modify_node_with_image(constructors_with_image, constructors_with_both):
_test_modify_node(constructors_with_image, constructors_with_both)
def test_modify_node_only_fw_version(
constructors_with_only_fw_version, constructors_with_both_versions
):
_test_modify_node(constructors_with_only_fw_version, constructors_with_both_versions)


def test_modify_node_only_py_version(
constructors_with_only_py_version, constructors_with_both_versions
):
_test_modify_node(constructors_with_only_py_version, constructors_with_both_versions)


def test_modify_node_only_named_fw_version():
_test_modify_node(
constructors(fw_version="named"), constructors(fw_version="named", py_version="literal")
)


def test_modify_node_only_named_py_version():
_test_modify_node(
constructors(py_version="named"), constructors(fw_version="literal", py_version="named")
)
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,20 @@ def test_node_should_be_modified_tf_constructor_script_mode():
"TensorFlow(py_version='py3')",
"TensorFlow(py_version='py37')",
"TensorFlow(py_version='py3', script_mode=False)",
"TensorFlow(py_version=py_version, script_mode=False)",
"TensorFlow(py_version='py3', script_mode=script_mode)",
"sagemaker.tensorflow.TensorFlow(script_mode=True)",
"sagemaker.tensorflow.TensorFlow(py_version='py3')",
"sagemaker.tensorflow.TensorFlow(py_version='py37')",
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
"sagemaker.tensorflow.TensorFlow(py_version=py_version, script_mode=False)",
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=script_mode)",
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=True)",
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3')",
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py37')",
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=False)",
"sagemaker.tensorflow.estimator.TensorFlow(py_version=py_version, script_mode=False)",
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=script_mode)",
)

modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
Expand Down