Skip to content

Commit af2b6d5

Browse files
committed
change: make v2 migration script add py_version when needed
1 parent 49bab2b commit af2b6d5

File tree

4 files changed

+271
-0
lines changed

4 files changed

+271
-0
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.cli.compatibility.v2 import modifiers
1919

2020
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.py_version.PyVersionEnforcer(),
2122
modifiers.framework_version.FrameworkVersionEnforcer(),
2223
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2324
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
1717
deprecated_params,
1818
framework_version,
19+
py_version,
1920
tf_legacy_mode,
2021
tfs,
2122
)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class to ensure that ``py_version`` is defined when constructing framework classes."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
19+
20+
PY_VERSION_ARG = "py_version"
21+
PY_VERSION_DEFAULT = "py3"
22+
23+
FRAMEWORK_MODEL_REQUIRES_PY_VERSION = {
24+
"Chainer": True,
25+
"MXNet": True,
26+
"PyTorch": True,
27+
"SKLearn": False,
28+
"TensorFlow": False,
29+
}
30+
31+
FRAMEWORK_CLASSES = list(FRAMEWORK_MODEL_REQUIRES_PY_VERSION.keys())
32+
33+
MODEL_CLASSES = [
34+
"{}Model".format(fw) for fw, required in FRAMEWORK_MODEL_REQUIRES_PY_VERSION.items() if required
35+
]
36+
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES]
37+
FRAMEWORK_SUBMODULES = ("model", "estimator")
38+
39+
40+
class PyVersionEnforcer(Modifier):
41+
"""A class to ensure that ``py_version`` is defined when
42+
instantiating a framework estimator or model, where appropriate.
43+
"""
44+
45+
def node_should_be_modified(self, node):
46+
"""Checks if the ast.Call node should be modified to include ``py_version``.
47+
48+
If the ast.Call node instantiates a framework estimator or model, but doesn't
49+
specify the ``py_version`` parameter when required, then the node should be
50+
modified. However, if ``image_name`` for a framework estimator or ``image``
51+
for a model is supplied to the call, then ``py_version`` is not required.
52+
53+
This looks for the following formats:
54+
55+
- ``PyTorch``
56+
- ``sagemaker.pytorch.PyTorch``
57+
58+
where "PyTorch" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.
59+
60+
Args:
61+
node (ast.Call): a node that represents a function call. For more,
62+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
63+
64+
Returns:
65+
bool: If the ``ast.Call`` is instantiating a framework class that
66+
should specify ``py_version``, but doesn't.
67+
"""
68+
if _is_named_constructor(node, FRAMEWORK_CLASSES):
69+
return _version_arg_needed(node, "image_name", PY_VERSION_ARG)
70+
71+
if _is_named_constructor(node, MODEL_CLASSES):
72+
return _version_arg_needed(node, "image", PY_VERSION_ARG)
73+
74+
return False
75+
76+
def modify_node(self, node):
77+
"""Modifies the ``ast.Call`` node's keywords to include ``py_version``.
78+
79+
Args:
80+
node (ast.Call): a node that represents the constructor of a framework class.
81+
"""
82+
node.keywords.append(ast.keyword(arg=PY_VERSION_ARG, value=ast.Str(s=PY_VERSION_DEFAULT)))
83+
84+
85+
def _is_named_constructor(node, names):
86+
"""Checks if the ``ast.Call`` node represents a call to particular named constructors.
87+
88+
Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework>
89+
where <Framework> belongs to the list of names passed in.
90+
"""
91+
# Check for call from particular names of constructors
92+
if isinstance(node.func, ast.Name):
93+
return node.func.id in names
94+
95+
# Check for something.that.ends.with.<framework>.<Framework> call for Framework in names
96+
if not (isinstance(node.func, ast.Attribute) and node.func.attr in names):
97+
return False
98+
99+
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
100+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES:
101+
return _is_in_framework_module(node.func.value)
102+
103+
# Check for sagemaker.<framework>.<Framework> call
104+
return _is_in_framework_module(node.func)
105+
106+
107+
def _is_in_framework_module(node):
108+
"""Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
109+
return (
110+
isinstance(node.value, ast.Attribute)
111+
and node.value.attr in FRAMEWORK_MODULES
112+
and isinstance(node.value.value, ast.Name)
113+
and node.value.value.id == "sagemaker"
114+
)
115+
116+
117+
def _version_arg_needed(node, image_arg, version_arg):
118+
"""Determines if image_arg or version_arg was supplied"""
119+
return not (_arg_supplied(node, image_arg) or _arg_supplied(node, version_arg))
120+
121+
122+
def _arg_supplied(node, arg):
123+
"""Checks if the ``ast.Call`` node's keywords contain ``arg``."""
124+
return any(kw.arg == arg and kw.value for kw in node.keywords)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sys
16+
17+
import pasta
18+
import pytest
19+
20+
from sagemaker.cli.compatibility.v2.modifiers import py_version
21+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
22+
23+
24+
@pytest.fixture(autouse=True)
25+
def skip_if_py2():
26+
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
27+
if sys.version_info.major < 3:
28+
pytest.skip("v2 migration script doesn't support Python 2.")
29+
30+
31+
@pytest.fixture
32+
def constructor_framework_templates():
33+
return (
34+
"TensorFlow({})",
35+
"sagemaker.tensorflow.TensorFlow({})",
36+
"sagemaker.tensorflow.estimator.TensorFlow({})",
37+
"MXNet({})",
38+
"sagemaker.mxnet.MXNet({})",
39+
"sagemaker.mxnet.estimator.MXNet({})",
40+
"Chainer({})",
41+
"sagemaker.chainer.Chainer({})",
42+
"sagemaker.chainer.estimator.Chainer({})",
43+
"PyTorch({})",
44+
"sagemaker.pytorch.PyTorch({})",
45+
"sagemaker.pytorch.estimator.PyTorch({})",
46+
"SKLearn({})",
47+
"sagemaker.sklearn.SKLearn({})",
48+
"sagemaker.sklearn.estimator.SKLearn({})",
49+
)
50+
51+
52+
@pytest.fixture
53+
def constructor_model_templates():
54+
return (
55+
"MXNetModel({})",
56+
"sagemaker.mxnet.MXNetModel({})",
57+
"sagemaker.mxnet.model.MXNetModel({})",
58+
"ChainerModel({})",
59+
"sagemaker.chainer.ChainerModel({})",
60+
"sagemaker.chainer.model.ChainerModel({})",
61+
"PyTorchModel({})",
62+
"sagemaker.pytorch.PyTorchModel({})",
63+
"sagemaker.pytorch.model.PyTorchModel({})",
64+
)
65+
66+
67+
@pytest.fixture
68+
def constructor_templates(constructor_framework_templates, constructor_model_templates):
69+
return tuple(list(constructor_framework_templates) + list(constructor_model_templates))
70+
71+
72+
@pytest.fixture
73+
def constructors_no_version(constructor_templates):
74+
return (ctr.format("") for ctr in constructor_templates)
75+
76+
77+
@pytest.fixture
78+
def constructors_with_version(constructor_templates):
79+
return (ctr.format("py_version='py3'") for ctr in constructor_templates)
80+
81+
82+
@pytest.fixture
83+
def constructors_with_image_name(constructor_framework_templates):
84+
return (ctr.format("image_name='my:image'") for ctr in constructor_framework_templates)
85+
86+
87+
@pytest.fixture
88+
def constructors_with_image(constructor_model_templates):
89+
return (ctr.format("image='my:image'") for ctr in constructor_model_templates)
90+
91+
92+
@pytest.fixture
93+
def constructors_version_not_needed():
94+
return (
95+
"TensorFlowModel()",
96+
"sagemaker.tensorflow.TensorFlowModel()",
97+
"sagemaker.tensorflow.model.TensorFlowModel()",
98+
"SKLearnModel()",
99+
"sagemaker.sklearn.SKLearnModel()",
100+
"sagemaker.sklearn.model.SKLearnModel()",
101+
)
102+
103+
104+
def _test_modified(constructors, should_be):
105+
modifier = py_version.PyVersionEnforcer()
106+
for constructor in constructors:
107+
node = ast_call(constructor)
108+
if should_be:
109+
assert modifier.node_should_be_modified(node)
110+
else:
111+
assert not modifier.node_should_be_modified(node)
112+
113+
114+
def test_node_should_be_modified_fw_constructor_no_version(constructors_no_version):
115+
_test_modified(constructors_no_version, should_be=True)
116+
117+
118+
def test_node_should_be_modified_fw_constructor_with_version(constructors_with_version):
119+
_test_modified(constructors_with_version, should_be=False)
120+
121+
122+
def test_node_should_be_modified_fw_constructor_with_image_name(constructors_with_image_name):
123+
_test_modified(constructors_with_image_name, should_be=False)
124+
125+
126+
def test_node_should_be_modified_fw_constructor_with_image(constructors_with_image):
127+
_test_modified(constructors_with_image, should_be=False)
128+
129+
130+
def test_node_should_be_modified_fw_constructor_not_needed(constructors_version_not_needed):
131+
_test_modified(constructors_version_not_needed, should_be=False)
132+
133+
134+
def test_node_should_be_modified_random_function_call():
135+
_test_modified(["sagemaker.session.Session()"], should_be=False)
136+
137+
138+
def test_modify_node(constructor_templates):
139+
modifier = py_version.PyVersionEnforcer()
140+
for template in constructor_templates:
141+
no_version, with_version = template.format(""), template.format("py_version='py3'")
142+
node = ast_call(no_version)
143+
modifier.modify_node(node)
144+
145+
assert with_version == pasta.dump(node)

0 commit comments

Comments
 (0)