29
29
30
30
# Module-level logger, looked up under the fixed name "sagemaker" (not ``__name__``).
logger = logging .getLogger ("sagemaker" )
31
31
32
- # TODO: consider creating a function for generating this command before removing this constant
33
- _SCRIPT_MODE_TENSORBOARD_WARNING = (
34
- "Tensorboard is not supported with script mode. You can run the following "
35
- "command: tensorboard --logdir %s --host localhost --port 6006 This can be "
36
- "run from anywhere with access to the S3 URI used as the logdir."
37
- )
38
-
39
32
40
33
class TensorFlow (Framework ):
41
34
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""
42
35
43
36
__framework_name__ = "tensorflow"
44
- _SCRIPT_MODE_REPO_NAME = "tensorflow-scriptmode"
37
+ _ECR_REPO_NAME = "tensorflow-scriptmode"
45
38
46
39
LATEST_VERSION = defaults .LATEST_VERSION
47
40
48
41
_LATEST_1X_VERSION = "1.15.2"
49
42
50
43
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version .Version ("1.10.0" )
51
- _LOWEST_SCRIPT_MODE_ONLY_VERSION = version .Version ("1.13.1" )
52
-
53
44
_HIGHEST_PYTHON_2_VERSION = version .Version ("2.1.0" )
54
45
55
46
def __init__ (
@@ -59,7 +50,6 @@ def __init__(
59
50
model_dir = None ,
60
51
image_name = None ,
61
52
distributions = None ,
62
- script_mode = True ,
63
53
** kwargs
64
54
):
65
55
"""Initialize a ``TensorFlow`` estimator.
@@ -82,6 +72,8 @@ def __init__(
82
72
* *Local Mode with local sources (file:// instead of s3://)* - \
83
73
``/opt/ml/shared/model``
84
74
75
+ To disable having ``model_dir`` passed to your training script,
76
+ set ``model_dir=False``.
85
77
image_name (str): If specified, the estimator will use this image for training and
86
78
hosting, instead of selecting the appropriate SageMaker official image based on
87
79
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -114,8 +106,6 @@ def __init__(
114
106
}
115
107
}
116
108
117
- script_mode (bool): Whether or not to use the Script Mode TensorFlow images
118
- (default: True).
119
109
**kwargs: Additional kwargs passed to the Framework constructor.
120
110
121
111
.. tip::
@@ -154,7 +144,6 @@ def __init__(
154
144
self .model_dir = model_dir
155
145
self .distributions = distributions or {}
156
146
157
- self ._script_mode_enabled = script_mode
158
147
self ._validate_args (py_version = py_version , framework_version = self .framework_version )
159
148
160
149
def _validate_args (self , py_version , framework_version ):
@@ -171,30 +160,33 @@ def _validate_args(self, py_version, framework_version):
171
160
)
172
161
raise AttributeError (msg )
173
162
174
- if (not self ._script_mode_enabled ) and self ._only_script_mode_supported ():
175
- logger .warning (
176
- "Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
163
+ if self ._only_legacy_mode_supported () and self .image_name is None :
164
+ additional_instructions = ""
165
+ if self .model_dir is not False :
166
+ additional_instructions = " and set 'model_dir=False'"
167
+
168
+ legacy_image_uri = fw .create_image_uri (
169
+ self .sagemaker_session .boto_region_name ,
170
+ "tensorflow" ,
171
+ self .train_instance_type ,
172
+ self .framework_version ,
173
+ self .py_version ,
177
174
)
178
- self ._script_mode_enabled = True
179
175
180
- if self ._only_legacy_mode_supported ():
181
176
# TODO: add link to docs to explain how to use legacy mode with v2
182
- logger .warning (
183
- "TF %s supports only legacy mode. If you were using any legacy mode parameters "
177
+ msg = (
178
+ "TF {} supports only legacy mode. Please supply the image URI directly with "
179
+ "'image_name={}'{}. If you were using any legacy mode parameters "
184
180
"(training_steps, evaluation_steps, checkpoint_path, requirements_file), "
185
- "make sure to pass them directly as hyperparameters instead." ,
186
- self .framework_version ,
187
- )
188
- self . _script_mode_enabled = False
181
+ "make sure to pass them directly as hyperparameters instead."
182
+ ). format ( self .framework_version , legacy_image_uri , additional_instructions )
183
+
184
+ raise ValueError ( msg )
189
185
190
186
def _only_legacy_mode_supported(self):
    """Check whether the framework version is at or below the legacy-mode-only cutoff.

    Returns:
        bool: True if ``self.framework_version``, parsed with ``version.Version``,
        is less than or equal to ``_HIGHEST_LEGACY_MODE_ONLY_VERSION``.
    """
    requested = version.Version(self.framework_version)
    return requested <= self._HIGHEST_LEGACY_MODE_ONLY_VERSION
193
189
194
- def _only_script_mode_supported (self ):
195
- """Placeholder docstring"""
196
- return version .Version (self .framework_version ) >= self ._LOWEST_SCRIPT_MODE_ONLY_VERSION
197
-
198
190
def _only_python_3_supported(self):
    """Check whether the framework version is above the Python 2 cutoff.

    Returns:
        bool: True if ``self.framework_version``, parsed with ``version.Version``,
        is strictly greater than ``_HIGHEST_PYTHON_2_VERSION``.
    """
    requested = version.Version(self.framework_version)
    return requested > self._HIGHEST_PYTHON_2_VERSION
@@ -214,10 +206,6 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
214
206
job_details , model_channel_name
215
207
)
216
208
217
- model_dir = init_params ["hyperparameters" ].pop ("model_dir" , None )
218
- if model_dir is not None :
219
- init_params ["model_dir" ] = model_dir
220
-
221
209
image_name = init_params .pop ("image" )
222
210
framework , py_version , tag , script_mode = fw .framework_name_from_image (image_name )
223
211
if not framework :
@@ -226,8 +214,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
226
214
init_params ["image_name" ] = image_name
227
215
return init_params
228
216
229
- if script_mode is None :
230
- init_params ["script_mode" ] = False
217
+ model_dir = init_params ["hyperparameters" ].pop ("model_dir" , None )
218
+ if model_dir :
219
+ init_params ["model_dir" ] = model_dir
220
+ elif script_mode is None :
221
+ init_params ["model_dir" ] = False
231
222
232
223
init_params ["py_version" ] = py_version
233
224
@@ -239,6 +230,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
239
230
"1.4" if tag == "1.0" else fw .framework_version_from_tag (tag )
240
231
)
241
232
233
+ # Legacy images are required to be passed in explicitly.
234
+ if not script_mode :
235
+ init_params ["image_name" ] = image_name
236
+
242
237
training_job_name = init_params ["base_job_name" ]
243
238
if framework != cls .__framework_name__ :
244
239
raise ValueError (
@@ -309,27 +304,26 @@ def hyperparameters(self):
309
304
hyperparameters = super (TensorFlow , self ).hyperparameters ()
310
305
additional_hyperparameters = {}
311
306
312
- if self ._script_mode_enabled :
313
- mpi_enabled = False
314
-
315
- if "parameter_server" in self .distributions :
316
- ps_enabled = self .distributions ["parameter_server" ].get ("enabled" , False )
317
- additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = ps_enabled
307
+ if "parameter_server" in self .distributions :
308
+ ps_enabled = self .distributions ["parameter_server" ].get ("enabled" , False )
309
+ additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = ps_enabled
318
310
319
- if "mpi" in self .distributions :
320
- mpi_dict = self .distributions ["mpi" ]
321
- mpi_enabled = mpi_dict .get ("enabled" , False )
322
- additional_hyperparameters [self .LAUNCH_MPI_ENV_NAME ] = mpi_enabled
311
+ mpi_enabled = False
312
+ if "mpi" in self .distributions :
313
+ mpi_dict = self .distributions ["mpi" ]
314
+ mpi_enabled = mpi_dict .get ("enabled" , False )
315
+ additional_hyperparameters [self .LAUNCH_MPI_ENV_NAME ] = mpi_enabled
323
316
324
- if mpi_dict .get ("processes_per_host" ):
325
- additional_hyperparameters [self .MPI_NUM_PROCESSES_PER_HOST ] = mpi_dict .get (
326
- "processes_per_host"
327
- )
328
-
329
- additional_hyperparameters [self .MPI_CUSTOM_MPI_OPTIONS ] = mpi_dict .get (
330
- "custom_mpi_options" , ""
317
+ if mpi_dict .get ("processes_per_host" ):
318
+ additional_hyperparameters [self .MPI_NUM_PROCESSES_PER_HOST ] = mpi_dict .get (
319
+ "processes_per_host"
331
320
)
332
321
322
+ additional_hyperparameters [self .MPI_CUSTOM_MPI_OPTIONS ] = mpi_dict .get (
323
+ "custom_mpi_options" , ""
324
+ )
325
+
326
+ if self .model_dir is not False :
333
327
self .model_dir = self .model_dir or self ._default_s3_path ("model" , mpi = mpi_enabled )
334
328
additional_hyperparameters ["model_dir" ] = self .model_dir
335
329
@@ -375,16 +369,13 @@ def train_image(self):
375
369
if self .image_name :
376
370
return self .image_name
377
371
378
- if self ._script_mode_enabled :
379
- return fw .create_image_uri (
380
- self .sagemaker_session .boto_region_name ,
381
- self ._SCRIPT_MODE_REPO_NAME ,
382
- self .train_instance_type ,
383
- self .framework_version ,
384
- self .py_version ,
385
- )
386
-
387
- return super (TensorFlow , self ).train_image ()
372
+ return fw .create_image_uri (
373
+ self .sagemaker_session .boto_region_name ,
374
+ self ._ECR_REPO_NAME ,
375
+ self .train_instance_type ,
376
+ self .framework_version ,
377
+ self .py_version ,
378
+ )
388
379
389
380
def transformer (
390
381
self ,
0 commit comments