Skip to content

Commit 89e6fd2

Browse files
committed
handle estimator.create_model()
1 parent 4c3166e commit 89e6fd2

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,39 @@ def new_param_name(self):
218218
return "image_uri"
219219

220220

221+
class EstimatorCreateModelImageURIRenamer(MethodParamRenamer):
222+
"""A class to rename ``image`` to ``image_uri`` in estimator ``create_model()`` methods."""
223+
224+
@property
225+
def calls_to_modify(self):
226+
"""A mapping of ``create_model`` to common variable names for estimators."""
227+
return {
228+
"create_model": (
229+
"estimator",
230+
"chainer",
231+
"mxnet",
232+
"mx",
233+
"pytorch",
234+
"rl",
235+
"sklearn",
236+
"tensorflow",
237+
"tf",
238+
"xgboost",
239+
"xgb",
240+
)
241+
}
242+
243+
@property
244+
def old_param_name(self):
245+
"""The previous name for the image URI argument."""
246+
return "image"
247+
248+
@property
249+
def new_param_name(self):
250+
"""The new name for the the image URI argument."""
251+
return "image_uri"
252+
253+
221254
class SessionCreateModelImageURIRenamer(MethodParamRenamer):
222255
"""A class to rename ``primary_container_image`` to ``image_uri``.
223256
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
ESTIMATORS = (
21+
"estimator",
22+
"chainer",
23+
"mxnet",
24+
"mx",
25+
"pytorch",
26+
"rl",
27+
"sklearn",
28+
"tensorflow",
29+
"tf",
30+
"xgboost",
31+
"xgb",
32+
)
33+
34+
35+
def test_node_should_be_modified():
36+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
37+
38+
for estimator in ESTIMATORS:
39+
call = "{}.create_model(image='my-image:latest')".format(estimator)
40+
assert modifier.node_should_be_modified(ast_call(call))
41+
42+
43+
def test_node_should_be_modified_no_distribution():
44+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
45+
46+
for estimator in ESTIMATORS:
47+
call = "{}.create_model()".format(estimator)
48+
assert not modifier.node_should_be_modified(ast_call(call))
49+
50+
51+
def test_node_should_be_modified_random_function_call():
52+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
53+
assert not modifier.node_should_be_modified(ast_call("create_model()"))
54+
55+
56+
def test_modify_node():
57+
node = ast_call("estimator.create_model(image=my_image)")
58+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
59+
modifier.modify_node(node)
60+
61+
expected = "estimator.create_model(image_uri=my_image)"
62+
assert expected == pasta.dump(node)

0 commit comments

Comments
 (0)