Skip to content

Commit de40298

Browse files
authored
Merge branch 'zwei' into rename-s3-input
2 parents 04cdc7b + 0e4c0fa commit de40298

File tree

4 files changed

+132
-50
lines changed

4 files changed

+132
-50
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import ast
1717

18+
from packaging.version import InvalidVersion, Version
19+
1820
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
1921
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2022

@@ -135,10 +137,15 @@ def _tf_py_version_default(framework_version):
135137
"""Gets the py_version default based on framework_version for TensorFlow."""
136138
if not framework_version:
137139
return "py2"
138-
version = [int(s) for s in framework_version.split(".")]
139-
if version < [1, 12]:
140+
141+
try:
142+
version = Version(framework_version)
143+
except InvalidVersion:
144+
return "py2"
145+
146+
if version < Version("1.12"):
140147
return "py2"
141-
if version < [2, 2]:
148+
if version < Version("2.2"):
142149
return "py3"
143150
return "py37"
144151

@@ -186,7 +193,6 @@ def _version_args_needed(node):
186193
framework, is_model = _framework_from_node(node)
187194
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
188195
if expecting_py_version:
189-
py_version = parsing.arg_value(node, PY_ARG)
190-
return py_version is None
196+
return not matching.has_arg(node, PY_ARG)
191197

192198
return False

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ def _is_legacy_mode(self, node):
7979

8080
for kw in node.keywords:
8181
if kw.arg == "script_mode":
82-
script_mode = bool(kw.value.value)
82+
script_mode = (
83+
bool(kw.value.value) if isinstance(kw.value, ast.NameConstant) else True
84+
)
8385
if kw.arg == "py_version":
84-
py_version = kw.value.s
86+
py_version = kw.value.s if isinstance(kw.value, ast.Str) else "py3"
8587

8688
return not (py_version.startswith("py3") or script_mode)
8789

@@ -124,7 +126,8 @@ def modify_node(self, node):
124126

125127
if add_image_uri:
126128
image_uri = self._image_uri_from_args(node.keywords)
127-
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
129+
if image_uri:
130+
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
128131

129132
node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))
130133

@@ -155,19 +158,22 @@ def _to_ast_keyword(self, hps):
155158
return None
156159

157160
def _image_uri_from_args(self, keywords):
158-
"""Returns a legacy TensorFlow image URI based on the estimator arguments."""
161+
"""Returns a legacy TensorFlow image URI based on the estimator arguments if possible."""
159162
tf_version = framework_version.FRAMEWORK_DEFAULTS["TensorFlow"]
160163
instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)
161164

162165
for kw in keywords:
163166
if kw.arg == "framework_version":
164-
tf_version = kw.value.s
167+
tf_version = kw.value.s if isinstance(kw.value, ast.Str) else None
165168
if kw.arg == "train_instance_type":
166-
instance_type = kw.value.s
169+
instance_type = kw.value.s if isinstance(kw.value, ast.Str) else None
167170

168-
return fw_utils.create_image_uri(
169-
self.region, "tensorflow", instance_type, tf_version, "py2"
170-
)
171+
if tf_version and instance_type:
172+
return fw_utils.create_image_uri(
173+
self.region, "tensorflow", instance_type, tf_version, "py2"
174+
)
175+
176+
return None
171177

172178

173179
class TensorBoardParameterRemover(Modifier):

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import pasta
1615
import pytest
1716

1817
from sagemaker.cli.compatibility.v2.modifiers import framework_version
@@ -36,8 +35,10 @@ def __init__(
3635
self.py_version = py_version
3736
self.py_version_for_model = py_version_for_model
3837

39-
def constructors(self, versions=False, image=False):
40-
return self._frameworks(versions, image) + self._models(versions, image)
38+
def constructors(self, fw_version=False, py_version=False, image=False):
39+
return self._frameworks(fw_version, py_version, image) + self._models(
40+
fw_version, py_version, image
41+
)
4142

4243
def _templates(self, model=False):
4344
module = self.framework.lower()
@@ -54,30 +55,38 @@ def _templates(self, model=False):
5455
for template in templates
5556
)
5657

57-
def _frameworks(self, versions=False, image=False):
58-
keywords = dict()
59-
if image:
60-
keywords["image_uri"] = "my:image"
61-
if versions:
62-
keywords["framework_version"] = self.framework_version
63-
keywords["py_version"] = self.py_version
58+
def _frameworks(self, fw_version=False, py_version=False, image=False):
59+
keywords = self._base_keywords(fw_version, image)
60+
if py_version:
61+
keywords["py_version"] = (
62+
"py_version" if py_version == "named" else "'{}'".format(self.py_version)
63+
)
6464
return _format_templates(keywords, self._templates())
6565

66-
def _models(self, versions=False, image=False):
66+
def _models(self, fw_version=False, py_version=False, image=False):
67+
keywords = self._base_keywords(fw_version, image)
68+
if py_version and self.py_version_for_model:
69+
keywords["py_version"] = (
70+
"py_version" if py_version == "named" else "'{}'".format(self.py_version)
71+
)
72+
return _format_templates(keywords, self._templates(model=True))
73+
74+
def _base_keywords(self, fw_version=False, image=False):
6775
keywords = dict()
6876
if image:
69-
keywords["image_uri"] = "my:image"
70-
if versions:
71-
keywords["framework_version"] = self.framework_version
72-
if self.py_version_for_model:
73-
keywords["py_version"] = self.py_version
74-
return _format_templates(keywords, self._templates(model=True))
77+
keywords["image_uri"] = "'my:image'"
78+
if fw_version:
79+
keywords["framework_version"] = (
80+
"fw_version" if fw_version == "named" else "'{}'".format(self.framework_version)
81+
)
82+
return keywords
7583

7684

7785
def _format_templates(keywords, templates):
7886
args = ", ".join(
79-
"{key}='{value}'".format(key=key, value=value) for key, value in keywords.items()
87+
"{key}={value}".format(key=key, value=value) for key, value in keywords.items()
8088
)
89+
8190
return [template.format(args) for template in templates]
8291

8392

@@ -100,8 +109,12 @@ def _format_templates(keywords, templates):
100109
]
101110

102111

103-
def constructors(versions=False, image=False):
104-
return [ctr for template in TEMPLATES for ctr in template.constructors(versions, image)]
112+
def constructors(fw_version=False, py_version=False, image=False):
113+
return [
114+
ctr
115+
for template in TEMPLATES
116+
for ctr in template.constructors(fw_version, py_version, image)
117+
]
105118

106119

107120
@pytest.fixture
@@ -110,18 +123,34 @@ def constructors_empty():
110123

111124

112125
@pytest.fixture
113-
def constructors_with_versions():
114-
return constructors(versions=True)
126+
def constructors_with_only_fw_version_that_need_py_version():
127+
ctrs = []
128+
for template in TEMPLATES:
129+
if template.py_version_for_model:
130+
ctrs.extend(template.constructors(fw_version=True))
131+
else:
132+
ctrs.extend(template._frameworks(fw_version=True))
133+
return ctrs
115134

116135

117136
@pytest.fixture
118-
def constructors_with_image():
119-
return constructors(image=True)
137+
def constructors_with_only_fw_version():
138+
return constructors(fw_version=True)
139+
140+
141+
@pytest.fixture
142+
def constructors_with_only_py_version():
143+
return constructors(py_version=True)
120144

121145

122146
@pytest.fixture
123-
def constructors_with_both():
124-
return constructors(versions=True, image=True)
147+
def constructors_with_both_versions():
148+
return constructors(fw_version=True, py_version=True)
149+
150+
151+
@pytest.fixture
152+
def constructors_with_image():
153+
return constructors(image=True)
125154

126155

127156
def _test_node_should_be_modified(ctrs, should_modify=True):
@@ -138,8 +167,20 @@ def test_node_should_be_modified_empty(constructors_empty):
138167
_test_node_should_be_modified(constructors_empty, should_modify=True)
139168

140169

141-
def test_node_should_be_modified_with_versions(constructors_with_versions):
142-
_test_node_should_be_modified(constructors_with_versions, should_modify=False)
170+
def test_node_should_be_modified_with_only_fw_versions(
171+
constructors_with_only_fw_version_that_need_py_version,
172+
):
173+
_test_node_should_be_modified(
174+
constructors_with_only_fw_version_that_need_py_version, should_modify=True
175+
)
176+
177+
178+
def test_node_should_be_modified_with_only_py_versions(constructors_with_only_py_version):
179+
_test_node_should_be_modified(constructors_with_only_py_version, should_modify=True)
180+
181+
182+
def test_node_should_be_modified_with_versions(constructors_with_both_versions):
183+
_test_node_should_be_modified(constructors_with_both_versions, should_modify=False)
143184

144185

145186
def test_node_should_be_modified_with_image(constructors_with_image):
@@ -155,17 +196,40 @@ def _test_modify_node(ctrs_before, ctrs_expected):
155196
for before, expected in zip(ctrs_before, ctrs_expected):
156197
node = ast_call(before)
157198
modifier.modify_node(node)
158-
# NOTE: this type of equality with pasta depends on ordering of args...
159-
assert expected == pasta.dump(node)
199+
_assert_equal_kwargs(ast_call(expected), node)
200+
201+
202+
def _assert_equal_kwargs(expected, actual):
203+
assert _keywords_for_node(expected) == _keywords_for_node(actual)
160204

161205

162-
def test_modify_node_empty(constructors_empty, constructors_with_versions):
163-
_test_modify_node(constructors_empty, constructors_with_versions)
206+
def _keywords_for_node(node):
207+
return {kw.arg: getattr(kw.value, kw.value._fields[0]) for kw in node.keywords}
164208

165209

166-
def test_modify_node_with_versions(constructors_with_versions):
167-
_test_modify_node(constructors_with_versions, constructors_with_versions)
210+
def test_modify_node_empty(constructors_empty, constructors_with_both_versions):
211+
_test_modify_node(constructors_empty, constructors_with_both_versions)
168212

169213

170-
def test_modify_node_with_image(constructors_with_image, constructors_with_both):
171-
_test_modify_node(constructors_with_image, constructors_with_both)
214+
def test_modify_node_only_fw_version(
215+
constructors_with_only_fw_version, constructors_with_both_versions
216+
):
217+
_test_modify_node(constructors_with_only_fw_version, constructors_with_both_versions)
218+
219+
220+
def test_modify_node_only_py_version(
221+
constructors_with_only_py_version, constructors_with_both_versions
222+
):
223+
_test_modify_node(constructors_with_only_py_version, constructors_with_both_versions)
224+
225+
226+
def test_modify_node_only_named_fw_version():
227+
_test_modify_node(
228+
constructors(fw_version="named"), constructors(fw_version="named", py_version="literal")
229+
)
230+
231+
232+
def test_modify_node_only_named_py_version():
233+
_test_modify_node(
234+
constructors(py_version="named"), constructors(fw_version="literal", py_version="named")
235+
)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,20 @@ def test_node_should_be_modified_tf_constructor_script_mode():
5151
"TensorFlow(py_version='py3')",
5252
"TensorFlow(py_version='py37')",
5353
"TensorFlow(py_version='py3', script_mode=False)",
54+
"TensorFlow(py_version=py_version, script_mode=False)",
55+
"TensorFlow(py_version='py3', script_mode=script_mode)",
5456
"sagemaker.tensorflow.TensorFlow(script_mode=True)",
5557
"sagemaker.tensorflow.TensorFlow(py_version='py3')",
5658
"sagemaker.tensorflow.TensorFlow(py_version='py37')",
5759
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
60+
"sagemaker.tensorflow.TensorFlow(py_version=py_version, script_mode=False)",
61+
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=script_mode)",
5862
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=True)",
5963
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3')",
6064
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py37')",
6165
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=False)",
66+
"sagemaker.tensorflow.estimator.TensorFlow(py_version=py_version, script_mode=False)",
67+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=script_mode)",
6268
)
6369

6470
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

0 commit comments

Comments
 (0)