Skip to content

Commit 394fde9

Browse files
committed
change: handle image_uri rename for Session methods
1 parent fe68d4e commit 394fde9

File tree

2 files changed

+167
-10
lines changed

2 files changed

+167
-10
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ def modify_node(self, node):
6868
keyword.arg = self.new_param_name
6969

7070

71+
class MethodParamRenamer(ParamRenamer):
72+
"""Abstract class to handle parameter renames for methods that belong to objects."""
73+
74+
def node_should_be_modified(self, node):
75+
"""Checks if the node matches any of the relevant functions and
76+
contains the parameter to be renamed.
77+
78+
This looks for a call of the form ``<object>.<method>``, and
79+
assumes the method cannot be called on its own.
80+
"""
81+
if isinstance(node.func, ast.Name):
82+
return False
83+
84+
return super(MethodParamRenamer, self).node_should_be_modified(node)
85+
86+
7187
class DistributionParameterRenamer(ParamRenamer):
7288
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
7389
MXNet and TensorFlow estimators.
@@ -100,7 +116,7 @@ def new_param_name(self):
100116
return "distribution"
101117

102118

103-
class S3SessionRenamer(ParamRenamer):
119+
class S3SessionRenamer(MethodParamRenamer):
104120
"""A class to rename the ``session`` attribute to ``sagemaker_session`` in
105121
``S3Uploader`` and ``S3Downloader``.
106122
@@ -139,15 +155,6 @@ def new_param_name(self):
139155
"""The new name for the SageMaker session argument."""
140156
return "sagemaker_session"
141157

142-
def node_should_be_modified(self, node):
143-
"""Checks if the node is one of the S3 utility functions and
144-
contains the ``session`` parameter.
145-
"""
146-
if isinstance(node.func, ast.Name):
147-
return False
148-
149-
return super(S3SessionRenamer, self).node_should_be_modified(node)
150-
151158

152159
class EstimatorImageURIRenamer(ParamRenamer):
153160
"""A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""
@@ -208,4 +215,60 @@ def old_param_name(self):
208215
@property
209216
def new_param_name(self):
210217
"""The new name for the image URI argument."""
218+
219+
220+
class SessionCreateModelImageURIRenamer(MethodParamRenamer):
221+
"""A class to rename ``primary_container_image`` to ``image_uri``.
222+
223+
This looks for the following calls:
224+
225+
- ``sagemaker_session.create_model_from_job()``
226+
- ``sess.create_model_from_job()``
227+
"""
228+
229+
@property
230+
def calls_to_modify(self):
231+
"""A mapping of ``create_model_from_job`` to common variable names for Session."""
232+
return {
233+
"create_model_from_job": ("sagemaker_session", "sess"),
234+
}
235+
236+
@property
237+
def old_param_name(self):
238+
"""The previous name for the image URI argument."""
239+
return "primary_container_image"
240+
241+
@property
242+
def new_param_name(self):
243+
"""The new name for the the image URI argument."""
244+
return "image_uri"
245+
246+
247+
class SessionCreateEndpointImageURIRenamer(MethodParamRenamer):
248+
"""A class to rename ``deployment_image`` to ``image_uri``.
249+
250+
This looks for the following calls:
251+
252+
- ``sagemaker_session.endpoint_from_job()``
253+
- ``sess.endpoint_from_job()``
254+
- ``sagemaker_session.endpoint_from_model_data()``
255+
- ``sess.endpoint_from_model_data()``
256+
"""
257+
258+
@property
259+
def calls_to_modify(self):
260+
"""A mapping of the ``endpoint_from_*`` functions to common variable names for Session."""
261+
return {
262+
"endpoint_from_job": ("sagemaker_session", "sess"),
263+
"endpoint_from_model_data": ("sagemaker_session", "sess"),
264+
}
265+
266+
@property
267+
def old_param_name(self):
268+
"""The previous name for the image URI argument."""
269+
return "deployment_image"
270+
271+
@property
272+
def new_param_name(self):
273+
"""The new name for the the image URI argument."""
211274
return "image_uri"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
CREATE_MODEL_TEMPLATES = (
21+
"sagemaker_session.create_model_from_job({})",
22+
"sess.create_model_from_job({})",
23+
)
24+
25+
CREATE_ENDPOINT_TEMPLATES = (
26+
"sagemaker_session.endpoint_from_job({})",
27+
"sagemaker_session.endpoint_from_model_data({})",
28+
"sess.endpoint_from_job({})",
29+
"sess.endpoint_from_model_data({})",
30+
)
31+
32+
33+
def test_create_model_node_should_be_modified():
34+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
35+
36+
for template in CREATE_MODEL_TEMPLATES:
37+
call = ast_call(template.format("primary_container_image=my_image"))
38+
assert modifier.node_should_be_modified(call)
39+
40+
41+
def test_create_model_node_should_be_modified_no_image():
42+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
43+
44+
for template in CREATE_MODEL_TEMPLATES:
45+
call = ast_call(template.format(""))
46+
assert not modifier.node_should_be_modified(call)
47+
48+
49+
def test_create_model_node_should_be_modified_random_function_call():
50+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
51+
assert not modifier.node_should_be_modified(ast_call("create_model()"))
52+
53+
54+
def test_create_model_modify_node():
55+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
56+
57+
for template in CREATE_MODEL_TEMPLATES:
58+
call = ast_call(template.format("primary_container_image=my_image"))
59+
modifier.modify_node(call)
60+
61+
expected = template.format("image_uri=my_image")
62+
assert expected == pasta.dump(call)
63+
64+
65+
def test_create_endpoint_node_should_be_modified():
66+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
67+
68+
for template in CREATE_ENDPOINT_TEMPLATES:
69+
call = ast_call(template.format("deployment_image=my_image"))
70+
assert modifier.node_should_be_modified(call)
71+
72+
73+
def test_create_endpoint_node_should_be_modified_no_image():
74+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
75+
76+
for template in CREATE_ENDPOINT_TEMPLATES:
77+
call = ast_call(template.format(""))
78+
assert not modifier.node_should_be_modified(call)
79+
80+
81+
def test_create_endpoint_node_should_be_modified_random_function_call():
82+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
83+
assert not modifier.node_should_be_modified(ast_call("create_endpoint()"))
84+
85+
86+
def test_create_endpoint_modify_node():
87+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
88+
89+
for template in CREATE_ENDPOINT_TEMPLATES:
90+
call = ast_call(template.format("deployment_image=my_image"))
91+
modifier.modify_node(call)
92+
93+
expected = template.format("image_uri=my_image")
94+
assert expected == pasta.dump(call)

0 commit comments

Comments
 (0)