@@ -68,6 +68,22 @@ def modify_node(self, node):
68
68
keyword .arg = self .new_param_name
69
69
70
70
71
+ class MethodParamRenamer (ParamRenamer ):
72
+ """Abstract class to handle parameter renames for methods that belong to objects."""
73
+
74
+ def node_should_be_modified (self , node ):
75
+ """Checks if the node matches any of the relevant functions and
76
+ contains the parameter to be renamed.
77
+
78
+ This looks for a call of the form ``<object>.<method>``, and
79
+ assumes the method cannot be called on its own.
80
+ """
81
+ if isinstance (node .func , ast .Name ):
82
+ return False
83
+
84
+ return super (MethodParamRenamer , self ).node_should_be_modified (node )
85
+
86
+
71
87
class DistributionParameterRenamer (ParamRenamer ):
72
88
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
73
89
MXNet and TensorFlow estimators.
@@ -100,7 +116,7 @@ def new_param_name(self):
100
116
return "distribution"
101
117
102
118
103
- class S3SessionRenamer (ParamRenamer ):
119
+ class S3SessionRenamer (MethodParamRenamer ):
104
120
"""A class to rename the ``session`` attribute to ``sagemaker_session`` in
105
121
``S3Uploader`` and ``S3Downloader``.
106
122
@@ -139,15 +155,6 @@ def new_param_name(self):
139
155
"""The new name for the SageMaker session argument."""
140
156
return "sagemaker_session"
141
157
142
- def node_should_be_modified (self , node ):
143
- """Checks if the node is one of the S3 utility functions and
144
- contains the ``session`` parameter.
145
- """
146
- if isinstance (node .func , ast .Name ):
147
- return False
148
-
149
- return super (S3SessionRenamer , self ).node_should_be_modified (node )
150
-
151
158
152
159
class EstimatorImageURIRenamer (ParamRenamer ):
153
160
"""A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""
@@ -208,4 +215,60 @@ def old_param_name(self):
208
215
@property
209
216
def new_param_name (self ):
210
217
"""The new name for the image URI argument."""
218
+
219
+
220
+ class SessionCreateModelImageURIRenamer (MethodParamRenamer ):
221
+ """A class to rename ``primary_container_image`` to ``image_uri``.
222
+
223
+ This looks for the following calls:
224
+
225
+ - ``sagemaker_session.create_model_from_job()``
226
+ - ``sess.create_model_from_job()``
227
+ """
228
+
229
+ @property
230
+ def calls_to_modify (self ):
231
+ """A mapping of ``create_model_from_job`` to common variable names for Session."""
232
+ return {
233
+ "create_model_from_job" : ("sagemaker_session" , "sess" ),
234
+ }
235
+
236
+ @property
237
+ def old_param_name (self ):
238
+ """The previous name for the image URI argument."""
239
+ return "primary_container_image"
240
+
241
+ @property
242
+ def new_param_name (self ):
243
+ """The new name for the the image URI argument."""
244
+ return "image_uri"
245
+
246
+
247
+ class SessionCreateEndpointImageURIRenamer (MethodParamRenamer ):
248
+ """A class to rename ``deployment_image`` to ``image_uri``.
249
+
250
+ This looks for the following calls:
251
+
252
+ - ``sagemaker_session.endpoint_from_job()``
253
+ - ``sess.endpoint_from_job()``
254
+ - ``sagemaker_session.endpoint_from_model_data()``
255
+ - ``sess.endpoint_from_model_data()``
256
+ """
257
+
258
+ @property
259
+ def calls_to_modify (self ):
260
+ """A mapping of the ``endpoint_from_*`` functions to common variable names for Session."""
261
+ return {
262
+ "endpoint_from_job" : ("sagemaker_session" , "sess" ),
263
+ "endpoint_from_model_data" : ("sagemaker_session" , "sess" ),
264
+ }
265
+
266
+ @property
267
+ def old_param_name (self ):
268
+ """The previous name for the image URI argument."""
269
+ return "deployment_image"
270
+
271
+ @property
272
+ def new_param_name (self ):
273
+ """The new name for the the image URI argument."""
211
274
return "image_uri"
0 commit comments