Skip to content

Commit a680be1

Browse files
authored
change: convert TF legacy mode parameters to hyperparameters in v2 migration script (#1534)
1 parent c65c80f commit a680be1

File tree

6 files changed

+334
-6
lines changed

6 files changed

+334
-6
lines changed

buildspec-unittests.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,12 @@ phases:
1818
- start_time=`date +%s`
1919
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
2020
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
21-
tox -e py27,py36,py37 --parallel all -- tests/unit
22-
- ./ci-scripts/displaytime.sh 'py27,py36,py37 unit' $start_time
21+
tox -e py36,py37 --parallel all -- tests/unit
22+
- ./ci-scripts/displaytime.sh 'py36,py37 unit' $start_time
23+
24+
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
25+
- start_time=`date +%s`
26+
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
27+
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
28+
IGNORE_COVERAGE=- tox -e py27 --parallel all -- tests/unit
29+
- ./ci-scripts/displaytime.sh 'py27 unit' $start_time

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
import ast
1717

18-
from sagemaker.cli.compatibility.v2.modifiers import framework_version
18+
from sagemaker.cli.compatibility.v2 import modifiers
1919

20-
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
20+
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.framework_version.FrameworkVersionEnforcer(),
22+
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
23+
]
2124

2225

2326
class ASTTransformer(ast.NodeTransformer):
@@ -38,4 +41,6 @@ def visit_Call(self, node):
3841
"""
3942
for function_checker in FUNCTION_CALL_MODIFIERS:
4043
function_checker.check_and_modify_node(node)
44+
45+
ast.fix_missing_locations(node)
4146
return node

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
"""Classes for modifying AST nodes"""
1414
from __future__ import absolute_import
15+
16+
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
framework_version,
18+
tf_legacy_mode,
19+
)

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def _is_framework_constructor(self, node):
6666
"""
6767
# Check for <Framework> call
6868
if isinstance(node.func, ast.Name):
69-
if node.func.id in FRAMEWORK_CLASSES:
70-
return True
69+
return node.func.id in FRAMEWORK_CLASSES
7170

7271
# Check for sagemaker.<framework>.<Framework> call
7372
ends_with_framework_constructor = (
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify TensorFlow legacy mode code to be compatible with SageMaker Python SDK v2."""
14+
# TODO: handle fit(run_tensorboard_locally=True)
15+
from __future__ import absolute_import
16+
17+
import ast
18+
19+
import six
20+
21+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
22+
23+
24+
class TensorFlowLegacyModeConstructorUpgrader(Modifier):
25+
"""A class to turn legacy mode parameters into hyperparameters when
26+
instantiating a TensorFlow estimator.
27+
"""
28+
29+
LEGACY_MODE_PARAMETERS = (
30+
"checkpoint_path",
31+
"evaluation_steps",
32+
"requirements_file",
33+
"training_steps",
34+
)
35+
36+
def node_should_be_modified(self, node):
37+
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode.
38+
39+
This looks for the following formats:
40+
41+
- ``TensorFlow``
42+
- ``sagemaker.tensorflow.TensorFlow``
43+
44+
Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
45+
and (2) if ``py_version`` is ``py2`` or not specified.
46+
47+
Args:
48+
node (ast.Call): a node that represents a function call. For more,
49+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
50+
51+
Returns:
52+
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with legacy mode.
53+
"""
54+
return self._is_tf_constructor(node) and self._is_legacy_mode(node)
55+
56+
def _is_tf_constructor(self, node):
57+
"""Checks if the ``ast.Call`` node represents a call of the form
58+
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
59+
"""
60+
# Check for TensorFlow()
61+
if isinstance(node.func, ast.Name):
62+
return node.func.id == "TensorFlow"
63+
64+
# Check for sagemaker.tensorflow.TensorFlow()
65+
ends_with_tensorflow_constructor = (
66+
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
67+
)
68+
69+
is_in_tensorflow_module = (
70+
isinstance(node.func.value, ast.Attribute)
71+
and node.func.value.attr == "tensorflow"
72+
and isinstance(node.func.value.value, ast.Name)
73+
and node.func.value.value.id == "sagemaker"
74+
)
75+
76+
return ends_with_tensorflow_constructor and is_in_tensorflow_module
77+
78+
def _is_legacy_mode(self, node):
79+
"""Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
80+
script_mode = False
81+
py_version = "py2"
82+
83+
for kw in node.keywords:
84+
if kw.arg == "script_mode":
85+
script_mode = bool(kw.value.value)
86+
if kw.arg == "py_version":
87+
py_version = kw.value.s
88+
89+
return not (py_version.startswith("py3") or script_mode)
90+
91+
def modify_node(self, node):
92+
"""Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
93+
into hyperparameters and set ``script_mode=False``.
94+
95+
The parameters that are converted into hyperparameters:
96+
97+
- ``training_steps``
98+
- ``evaluation_steps``
99+
- ``checkpoint_path``
100+
- ``requirements_file``
101+
102+
Args:
103+
node (ast.Call): a node that represents a TensorFlow constructor.
104+
"""
105+
base_hps = {}
106+
additional_hps = {}
107+
kw_to_remove = [] # remove keyword args after so that none are skipped during iteration
108+
109+
for kw in node.keywords:
110+
if kw.arg == "script_mode":
111+
# remove here because is set to False later regardless of current value
112+
kw_to_remove.append(kw)
113+
if kw.arg == "hyperparameters" and kw.value:
114+
base_hps = dict(zip(kw.value.keys, kw.value.values))
115+
kw_to_remove.append(kw)
116+
if kw.arg in self.LEGACY_MODE_PARAMETERS and kw.value:
117+
hp_key = self._hyperparameter_key_for_param(kw.arg)
118+
additional_hps[hp_key] = kw.value
119+
kw_to_remove.append(kw)
120+
121+
self._remove_keywords(node, kw_to_remove)
122+
self._add_updated_hyperparameters(node, base_hps, additional_hps)
123+
124+
node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False)))
125+
126+
def _hyperparameter_key_for_param(self, arg):
127+
"""Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
128+
name = "sagemaker_requirements" if arg == "requirements_file" else arg
129+
return ast.Str(s=name)
130+
131+
def _remove_keywords(self, node, keywords):
132+
"""Removes the keywords from the ``ast.Call`` node."""
133+
for kw in keywords:
134+
node.keywords.remove(kw)
135+
136+
def _add_updated_hyperparameters(self, node, base_hps, additional_hps):
137+
"""Combines and adds the hyperparameters to the ``ast.Call`` node's keywords."""
138+
base_hps.update(additional_hps)
139+
updated_hp_keyword = self._to_ast_keyword(base_hps)
140+
141+
if updated_hp_keyword:
142+
node.keywords.append(updated_hp_keyword)
143+
144+
def _to_ast_keyword(self, hps):
145+
"""Returns an ``ast.keyword`` for the ``hyperparameters`` kwarg if there are any."""
146+
if hps:
147+
keys, values = zip(*six.iteritems(hps))
148+
return ast.keyword(arg="hyperparameters", value=ast.Dict(keys=keys, values=values))
149+
150+
return None
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sys
16+
17+
import pasta
18+
import pytest
19+
20+
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
21+
22+
23+
@pytest.fixture(autouse=True)
24+
def skip_if_py2():
25+
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
26+
if sys.version_info.major < 3:
27+
pytest.skip("v2 migration script doesn't support Python 2.")
28+
29+
30+
def test_node_should_be_modified_tf_constructor_legacy_mode():
31+
tf_legacy_mode_constructors = (
32+
"TensorFlow(script_mode=False)",
33+
"TensorFlow(script_mode=None)",
34+
"TensorFlow(py_version='py2')",
35+
"TensorFlow()",
36+
"sagemaker.tensorflow.TensorFlow(script_mode=False)",
37+
"sagemaker.tensorflow.TensorFlow(script_mode=None)",
38+
"sagemaker.tensorflow.TensorFlow(py_version='py2')",
39+
"sagemaker.tensorflow.TensorFlow()",
40+
)
41+
42+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
43+
44+
for constructor in tf_legacy_mode_constructors:
45+
node = _ast_call(constructor)
46+
assert modifier.node_should_be_modified(node) is True
47+
48+
49+
def test_node_should_be_modified_tf_constructor_script_mode():
50+
tf_script_mode_constructors = (
51+
"TensorFlow(script_mode=True)",
52+
"TensorFlow(py_version='py3')",
53+
"TensorFlow(py_version='py37')",
54+
"TensorFlow(py_version='py3', script_mode=False)",
55+
"sagemaker.tensorflow.TensorFlow(script_mode=True)",
56+
"sagemaker.tensorflow.TensorFlow(py_version='py3')",
57+
"sagemaker.tensorflow.TensorFlow(py_version='py37')",
58+
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
59+
)
60+
61+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
62+
63+
for constructor in tf_script_mode_constructors:
64+
node = _ast_call(constructor)
65+
assert modifier.node_should_be_modified(node) is False
66+
67+
68+
def test_node_should_be_modified_random_function_call():
69+
node = _ast_call("MXNet(py_version='py3')")
70+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
71+
assert modifier.node_should_be_modified(node) is False
72+
73+
74+
def test_modify_node_set_script_mode_false():
75+
tf_constructors = (
76+
"TensorFlow()",
77+
"TensorFlow(script_mode=False)",
78+
"TensorFlow(script_mode=None)",
79+
)
80+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
81+
82+
for constructor in tf_constructors:
83+
node = _ast_call(constructor)
84+
modifier.modify_node(node)
85+
assert "TensorFlow(script_mode=False)" == pasta.dump(node)
86+
87+
88+
def test_modify_node_set_hyperparameters():
89+
tf_constructor = """TensorFlow(
90+
checkpoint_path='s3://foo/bar',
91+
training_steps=100,
92+
evaluation_steps=10,
93+
requirements_file='source/requirements.txt',
94+
)"""
95+
96+
node = _ast_call(tf_constructor)
97+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
98+
modifier.modify_node(node)
99+
100+
expected_hyperparameters = {
101+
"checkpoint_path": "s3://foo/bar",
102+
"evaluation_steps": 10,
103+
"sagemaker_requirements": "source/requirements.txt",
104+
"training_steps": 100,
105+
}
106+
107+
assert expected_hyperparameters == _hyperparameters_from_node(node)
108+
109+
110+
def test_modify_node_preserve_other_hyperparameters():
111+
tf_constructor = """sagemaker.tensorflow.TensorFlow(
112+
training_steps=100,
113+
evaluation_steps=10,
114+
requirements_file='source/requirements.txt',
115+
hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'},
116+
)"""
117+
118+
node = _ast_call(tf_constructor)
119+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
120+
modifier.modify_node(node)
121+
122+
expected_hyperparameters = {
123+
"optimizer": "sgd",
124+
"lr": 0.1,
125+
"checkpoint_path": "s3://foo/bar",
126+
"evaluation_steps": 10,
127+
"sagemaker_requirements": "source/requirements.txt",
128+
"training_steps": 100,
129+
}
130+
131+
assert expected_hyperparameters == _hyperparameters_from_node(node)
132+
133+
134+
def test_modify_node_prefer_param_over_hyperparameter():
135+
tf_constructor = """sagemaker.tensorflow.TensorFlow(
136+
training_steps=100,
137+
requirements_file='source/requirements.txt',
138+
hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'},
139+
)"""
140+
141+
node = _ast_call(tf_constructor)
142+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
143+
modifier.modify_node(node)
144+
145+
expected_hyperparameters = {
146+
"sagemaker_requirements": "source/requirements.txt",
147+
"training_steps": 100,
148+
}
149+
150+
assert expected_hyperparameters == _hyperparameters_from_node(node)
151+
152+
153+
def _hyperparameters_from_node(node):
154+
for kw in node.keywords:
155+
if kw.arg == "hyperparameters":
156+
keys = [k.s for k in kw.value.keys]
157+
values = [getattr(v, v._fields[0]) for v in kw.value.values]
158+
return dict(zip(keys, values))
159+
160+
161+
def _ast_call(code):
162+
return pasta.parse(code).body[0].value

0 commit comments

Comments
 (0)