Skip to content

Commit b584b7e

Browse files
committed
Merge branch 'zwei' into tf-docs
2 parents e9f429e + baf1c35 commit b584b7e

File tree

7 files changed

+88
-81
lines changed

7 files changed

+88
-81
lines changed

doc/conf.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import pkg_resources
17-
import sys
1817
from datetime import datetime
19-
from unittest.mock import MagicMock
20-
21-
22-
class Mock(MagicMock):
23-
@classmethod
24-
def __getattr__(cls, name):
25-
"""
26-
Args:
27-
name:
28-
"""
29-
if name == "__version__":
30-
return "1.4.0"
31-
else:
32-
return MagicMock()
33-
34-
35-
MOCK_MODULES = [
36-
"tensorflow",
37-
"tensorflow.core",
38-
"tensorflow.core.framework",
39-
"tensorflow.python",
40-
"tensorflow.python.framework",
41-
"tensorflow_serving",
42-
"tensorflow_serving.apis",
43-
]
44-
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
4518

4619
project = u"sagemaker"
4720
version = pkg_resources.require(project)[0].version

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def read_version():
3535
# Declare minimal set for installation
3636
required_packages = [
3737
"boto3>=1.13.6",
38+
"google-pasta",
3839
"numpy>=1.9.0",
3940
"protobuf>=3.1",
4041
"protobuf3-to-dict>=0.1.5",
@@ -51,7 +52,6 @@ def read_version():
5152
"docker-compose>=1.25.2",
5253
"PyYAML>=5.3, <6", # PyYAML version has to match docker-compose requirements
5354
],
54-
"tensorflow": ["tensorflow>=1.3.0"],
5555
"scipy": ["scipy>=0.19.0"],
5656
}
5757
# Meta dependency groups

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# TODO: check for sagemaker.tensorflow.serving.Model
3030
FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS]
3131
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS]
32+
FRAMEWORK_SUBMODULES = ("model", "estimator")
3233

3334

3435
class FrameworkVersionEnforcer(Modifier):
@@ -68,19 +69,30 @@ def _is_framework_constructor(self, node):
6869
if isinstance(node.func, ast.Name):
6970
return node.func.id in FRAMEWORK_CLASSES
7071

71-
# Check for sagemaker.<framework>.<Framework> call
72-
ends_with_framework_constructor = (
73-
isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES
74-
)
72+
# Check for something.that.ends.with.<framework>.<Framework> call
73+
if not (isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES):
74+
return False
7575

76-
is_in_framework_module = (
76+
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
77+
if (
7778
isinstance(node.func.value, ast.Attribute)
78-
and node.func.value.attr in FRAMEWORK_MODULES
79-
and isinstance(node.func.value.value, ast.Name)
80-
and node.func.value.value.id == "sagemaker"
81-
)
79+
and node.func.value.attr in FRAMEWORK_SUBMODULES
80+
):
81+
return self._is_in_framework_module(node.func.value)
8282

83-
return ends_with_framework_constructor and is_in_framework_module
83+
# Check for sagemaker.<framework>.<Framework> call
84+
return self._is_in_framework_module(node.func)
85+
86+
def _is_in_framework_module(self, node):
87+
"""Checks if the node is an ``ast.Attribute`` that represents a
88+
``sagemaker.<framework>`` module.
89+
"""
90+
return (
91+
isinstance(node.value, ast.Attribute)
92+
and node.value.attr in FRAMEWORK_MODULES
93+
and isinstance(node.value.value, ast.Name)
94+
and node.value.value.id == "sagemaker"
95+
)
8496

8597
def _fw_version_in_keywords(self, node):
8698
"""Checks if the ``ast.Call`` node's keywords contain ``framework_version``."""

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def node_should_be_modified(self, node):
5454
5555
- ``TensorFlow``
5656
- ``sagemaker.tensorflow.TensorFlow``
57+
- ``sagemaker.tensorflow.estimator.TensorFlow``
5758
5859
Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
5960
and (2) if ``py_version`` is ``py2`` or not specified.
@@ -68,27 +69,35 @@ def node_should_be_modified(self, node):
6869
return self._is_tf_constructor(node) and self._is_legacy_mode(node)
6970

7071
def _is_tf_constructor(self, node):
71-
"""Checks if the ``ast.Call`` node represents a call of the form
72-
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
72+
"""Checks if the ``ast.Call`` node represents a call of the form ``TensorFlow``,
73+
``sagemaker.tensorflow.TensorFlow``, or ``sagemaker.tensorflow.estimator.TensorFlow``.
7374
"""
7475
# Check for TensorFlow()
7576
if isinstance(node.func, ast.Name):
7677
return node.func.id == "TensorFlow"
7778

79+
# Check for something.that.ends.with.TensorFlow()
80+
if not (isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"):
81+
return False
82+
83+
# Check for sagemaker.tensorflow.estimator.TensorFlow()
84+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == "estimator":
85+
return self._is_in_tensorflow_module(node.func.value)
86+
7887
# Check for sagemaker.tensorflow.TensorFlow()
79-
ends_with_tensorflow_constructor = (
80-
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
81-
)
88+
return self._is_in_tensorflow_module(node.func)
8289

83-
is_in_tensorflow_module = (
84-
isinstance(node.func.value, ast.Attribute)
85-
and node.func.value.attr == "tensorflow"
86-
and isinstance(node.func.value.value, ast.Name)
87-
and node.func.value.value.id == "sagemaker"
90+
def _is_in_tensorflow_module(self, node):
91+
"""Checks if the node is an ``ast.Attribute`` that represents the
92+
``sagemaker.tensorflow`` module.
93+
"""
94+
return (
95+
isinstance(node.value, ast.Attribute)
96+
and node.value.attr == "tensorflow"
97+
and isinstance(node.value.value, ast.Name)
98+
and node.value.value.id == "sagemaker"
8899
)
89100

90-
return ends_with_tensorflow_constructor and is_in_tensorflow_module
91-
92101
def _is_legacy_mode(self, node):
93102
"""Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
94103
script_mode = False

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,34 @@ def test_node_should_be_modified_fw_constructor_no_fw_version():
3232
fw_constructors = (
3333
"TensorFlow()",
3434
"sagemaker.tensorflow.TensorFlow()",
35+
"sagemaker.tensorflow.estimator.TensorFlow()",
3536
"TensorFlowModel()",
3637
"sagemaker.tensorflow.TensorFlowModel()",
38+
"sagemaker.tensorflow.model.TensorFlowModel()",
3739
"MXNet()",
3840
"sagemaker.mxnet.MXNet()",
41+
"sagemaker.mxnet.estimator.MXNet()",
3942
"MXNetModel()",
4043
"sagemaker.mxnet.MXNetModel()",
44+
"sagemaker.mxnet.model.MXNetModel()",
4145
"Chainer()",
4246
"sagemaker.chainer.Chainer()",
47+
"sagemaker.chainer.estimator.Chainer()",
4348
"ChainerModel()",
4449
"sagemaker.chainer.ChainerModel()",
50+
"sagemaker.chainer.model.ChainerModel()",
4551
"PyTorch()",
4652
"sagemaker.pytorch.PyTorch()",
53+
"sagemaker.pytorch.estimator.PyTorch()",
4754
"PyTorchModel()",
4855
"sagemaker.pytorch.PyTorchModel()",
56+
"sagemaker.pytorch.model.PyTorchModel()",
4957
"SKLearn()",
5058
"sagemaker.sklearn.SKLearn()",
59+
"sagemaker.sklearn.estimator.SKLearn()",
5160
"SKLearnModel()",
5261
"sagemaker.sklearn.SKLearnModel()",
62+
"sagemaker.sklearn.model.SKLearnModel()",
5363
)
5464

5565
modifier = framework_version.FrameworkVersionEnforcer()
@@ -63,24 +73,34 @@ def test_node_should_be_modified_fw_constructor_with_fw_version():
6373
fw_constructors = (
6474
"TensorFlow(framework_version='2.2')",
6575
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')",
76+
"sagemaker.tensorflow.estimator.TensorFlow(framework_version='2.2')",
6677
"TensorFlowModel(framework_version='1.10')",
6778
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')",
79+
"sagemaker.tensorflow.model.TensorFlowModel(framework_version='1.10')",
6880
"MXNet(framework_version='1.6')",
6981
"sagemaker.mxnet.MXNet(framework_version='1.6')",
82+
"sagemaker.mxnet.estimator.MXNet(framework_version='1.6')",
7083
"MXNetModel(framework_version='1.6')",
7184
"sagemaker.mxnet.MXNetModel(framework_version='1.6')",
85+
"sagemaker.mxnet.model.MXNetModel(framework_version='1.6')",
7286
"PyTorch(framework_version='1.4')",
7387
"sagemaker.pytorch.PyTorch(framework_version='1.4')",
88+
"sagemaker.pytorch.estimator.PyTorch(framework_version='1.4')",
7489
"PyTorchModel(framework_version='1.4')",
7590
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')",
91+
"sagemaker.pytorch.model.PyTorchModel(framework_version='1.4')",
7692
"Chainer(framework_version='5.0')",
7793
"sagemaker.chainer.Chainer(framework_version='5.0')",
94+
"sagemaker.chainer.estimator.Chainer(framework_version='5.0')",
7895
"ChainerModel(framework_version='5.0')",
7996
"sagemaker.chainer.ChainerModel(framework_version='5.0')",
97+
"sagemaker.chainer.model.ChainerModel(framework_version='5.0')",
8098
"SKLearn(framework_version='0.20.0')",
8199
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')",
100+
"sagemaker.sklearn.estimator.SKLearn(framework_version='0.20.0')",
82101
"SKLearnModel(framework_version='0.20.0')",
83102
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')",
103+
"sagemaker.sklearn.model.SKLearnModel(framework_version='0.20.0')",
84104
)
85105

86106
modifier = framework_version.FrameworkVersionEnforcer()
@@ -97,51 +117,36 @@ def test_node_should_be_modified_random_function_call():
97117

98118

99119
def test_modify_node_tf():
100-
classes = (
101-
"TensorFlow" "sagemaker.tensorflow.TensorFlow",
102-
"TensorFlowModel",
103-
"sagemaker.tensorflow.TensorFlowModel",
104-
)
105-
_test_modify_node(classes, "1.11.0")
120+
_test_modify_node("TensorFlow", "1.11.0")
106121

107122

108123
def test_modify_node_mx():
109-
classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel")
110-
_test_modify_node(classes, "1.2.0")
124+
_test_modify_node("MXNet", "1.2.0")
111125

112126

113127
def test_modify_node_chainer():
114-
classes = (
115-
"Chainer",
116-
"sagemaker.chainer.Chainer",
117-
"ChainerModel",
118-
"sagemaker.chainer.ChainerModel",
119-
)
120-
_test_modify_node(classes, "4.1.0")
128+
_test_modify_node("Chainer", "4.1.0")
121129

122130

123131
def test_modify_node_pt():
124-
classes = (
125-
"PyTorch",
126-
"sagemaker.pytorch.PyTorch",
127-
"PyTorchModel",
128-
"sagemaker.pytorch.PyTorchModel",
129-
)
130-
_test_modify_node(classes, "0.4.0")
132+
_test_modify_node("PyTorch", "0.4.0")
131133

132134

133135
def test_modify_node_sklearn():
134-
classes = (
135-
"SKLearn",
136-
"sagemaker.sklearn.SKLearn",
137-
"SKLearnModel",
138-
"sagemaker.sklearn.SKLearnModel",
139-
)
140-
_test_modify_node(classes, "0.20.0")
136+
_test_modify_node("SKLearn", "0.20.0")
141137

142138

143-
def _test_modify_node(classes, default_version):
139+
def _test_modify_node(framework, default_version):
144140
modifier = framework_version.FrameworkVersionEnforcer()
141+
142+
classes = (
143+
"{}".format(framework),
144+
"sagemaker.{}.{}".format(framework.lower(), framework),
145+
"sagemaker.{}.estimator.{}".format(framework.lower(), framework),
146+
"{}Model".format(framework),
147+
"sagemaker.{}.{}Model".format(framework.lower(), framework),
148+
"sagemaker.{}.model.{}Model".format(framework.lower(), framework),
149+
)
145150
for cls in classes:
146151
node = ast_call("{}()".format(cls))
147152
modifier.modify_node(node)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def test_node_should_be_modified_tf_constructor_legacy_mode():
4242
"sagemaker.tensorflow.TensorFlow(script_mode=None)",
4343
"sagemaker.tensorflow.TensorFlow(py_version='py2')",
4444
"sagemaker.tensorflow.TensorFlow()",
45+
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=False)",
46+
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=None)",
47+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py2')",
48+
"sagemaker.tensorflow.estimator.TensorFlow()",
4549
)
4650

4751
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
@@ -61,6 +65,10 @@ def test_node_should_be_modified_tf_constructor_script_mode():
6165
"sagemaker.tensorflow.TensorFlow(py_version='py3')",
6266
"sagemaker.tensorflow.TensorFlow(py_version='py37')",
6367
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
68+
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=True)",
69+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3')",
70+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py37')",
71+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=False)",
6472
)
6573

6674
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ passenv =
6363
# Can be used to specify which tests to run, e.g.: tox -- -s
6464
commands =
6565
coverage run --source sagemaker -m pytest {posargs}
66-
{env:IGNORE_COVERAGE:} coverage report --fail-under=86 --omit */tensorflow/tensorflow_serving/*
66+
{env:IGNORE_COVERAGE:} coverage report --fail-under=86
6767
extras = test
6868

6969
[testenv:flake8]

0 commit comments

Comments
 (0)