22
22
import time
23
23
24
24
from sagemaker .estimator import Framework
25
- from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag , \
26
- empty_framework_version_warning
25
+ import sagemaker .fw_utils as fw
27
26
from sagemaker .tensorflow .defaults import TF_VERSION
28
27
from sagemaker .tensorflow .model import TensorFlowModel
29
28
from sagemaker .tensorflow .serving import Model
34
33
LOGGER = logging.getLogger('sagemaker')


# Estimator arguments that are rejected when script mode is enabled; the names are
# listed in the error raised by ``TensorFlow._validate_args``.
_FRAMEWORK_MODE_ARGS = ('training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path')

# Framework name handed to ``fw.create_image_uri`` when building the script mode training image URI.
_SCRIPT_MODE = 'tensorflow-scriptmode'

# Error text explaining that script mode models must be served with the tensorflow-serving container.
_SCRIPT_MODE_SERVING_ERROR_MSG = (
    'Script mode containers does not support serving yet. '
    'Please use our new tensorflow-serving container by creating the model '
    'with \'endpoint_type\' set to \'tensorflow-serving\'.'
)

# Warning logged by ``fit`` when local Tensorboard is requested in script mode; the ``{}``
# placeholder is filled with the estimator's ``model_dir``.
_SCRIPT_MODE_TENSORBOARD_WARNING = (
    'Tensorboard is not supported with script mode. You can run the following '
    'command: tensorboard --logdir {} --host localhost --port 6006 This can be '
    'run from anywhere with access to the S3 URI used as the logdir.'
)
44
+
45
+
37
46
class Tensorboard (threading .Thread ):
38
47
def __init__ (self , estimator , logdir = None ):
39
48
"""Initialize ``Tensorboard`` instance.
@@ -163,9 +172,9 @@ class TensorFlow(Framework):
163
172
164
173
__framework_name__ = 'tensorflow'
165
174
166
- def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None ,
167
- py_version = 'py2' , framework_version = None , requirements_file = '' , image_name = None ,
168
- ** kwargs ):
175
+ def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None , py_version = 'py2' ,
176
+ framework_version = None , model_dir = None , requirements_file = '' , image_name = None ,
177
+ script_mode = False , distributions = None , ** kwargs ):
169
178
"""Initialize an ``TensorFlow`` estimator.
170
179
Args:
171
180
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
@@ -176,6 +185,9 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
176
185
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
177
186
framework_version (str): TensorFlow version you want to use for executing your model training code.
178
187
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
188
+ model_dir (str): S3 location where the checkpoint data and models can be exported to during training
189
+ (default: None). If not specified, a default S3 URI will be generated. It will be passed to the
190
+ training script as one of the command line arguments.
179
191
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
180
192
relative to ``source_dir``. Details on the format can be found in the
181
193
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
@@ -185,21 +197,61 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
185
197
Examples:
186
198
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
187
199
custom-image:latest.
200
+ script_mode (bool): If set to True, the estimator will use the Script Mode containers (default: False).
201
+ This will be ignored if py_version is set to 'py3'.
202
+ distributions (dict): A dictionary with information on how to run distributed training
203
+ (default: None). Currently we only support distributed training with parameter servers. To enable it
204
+ use the following setup:
205
+ {
206
+ 'parameter_server':
207
+ {
208
+ 'enabled': True
209
+ }
210
+ }
188
211
**kwargs: Additional kwargs passed to the Framework constructor.
189
212
"""
190
213
if framework_version is None :
191
- LOGGER .warning (empty_framework_version_warning (TF_VERSION , TF_VERSION ))
214
+ LOGGER .warning (fw . empty_framework_version_warning (TF_VERSION , TF_VERSION ))
192
215
self .framework_version = framework_version or TF_VERSION
193
216
194
217
super (TensorFlow , self ).__init__ (image_name = image_name , ** kwargs )
195
218
self .checkpoint_path = checkpoint_path
196
219
self .py_version = py_version
197
220
self .training_steps = training_steps
198
221
self .evaluation_steps = evaluation_steps
222
+ self .model_dir = model_dir
223
+ self .script_mode = script_mode
224
+ self .distributions = distributions or {}
199
225
226
+ self ._validate_args (py_version = py_version , script_mode = script_mode , framework_version = framework_version ,
227
+ training_steps = training_steps , evaluation_steps = evaluation_steps ,
228
+ requirements_file = requirements_file , checkpoint_path = checkpoint_path )
200
229
self ._validate_requirements_file (requirements_file )
201
230
self .requirements_file = requirements_file
202
231
232
def _validate_args(self, py_version, script_mode, framework_version, training_steps,
                   evaluation_steps, requirements_file, checkpoint_path):
    """Reject constructor arguments that are not allowed in script mode.

    Nothing is checked unless ``py_version`` is ``'py3'`` or ``script_mode`` is truthy.

    Args:
        py_version (str): Python version requested for the training code.
        script_mode (bool): Whether script mode was explicitly requested.
        framework_version (str): TensorFlow version passed by the caller; must not be None
            when script mode applies.
        training_steps: Framework mode argument; must be unset in script mode.
        evaluation_steps: Framework mode argument; must be unset in script mode.
        requirements_file: Framework mode argument; must be unset in script mode.
        checkpoint_path: Framework mode argument; must be unset in script mode.

    Raises:
        AttributeError: If script mode applies and ``framework_version`` is None, or if any
            of the framework mode arguments has a truthy value.
    """
    # Framework mode accepts every argument, so there is nothing to validate.
    if not (py_version == 'py3' or script_mode):
        return

    if framework_version is None:
        raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)

    # Pair each framework-mode-only argument with its value, in the order they are reported.
    candidates = (
        ('training_steps', training_steps),
        ('evaluation_steps', evaluation_steps),
        ('requirements_file', requirements_file),
        ('checkpoint_path', checkpoint_path),
    )
    offending = [name for name, value in candidates if value]
    if not offending:
        return

    message = '{} are deprecated in script mode. Please do not set {}.'.format(
        ', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(offending))
    raise AttributeError(message)
254
+
203
255
def _validate_requirements_file (self , requirements_file ):
204
256
if not requirements_file :
205
257
return
@@ -245,7 +297,10 @@ def fit_super():
245
297
if run_tensorboard_locally and wait is False :
246
298
raise ValueError ("Tensorboard is not supported with async fit" )
247
299
248
- if run_tensorboard_locally :
300
+ if self ._script_mode_enabled () and run_tensorboard_locally :
301
+ LOGGER .warning (_SCRIPT_MODE_TENSORBOARD_WARNING .format (self .model_dir ))
302
+ fit_super ()
303
+ elif run_tensorboard_locally :
249
304
tensorboard = Tensorboard (self )
250
305
tensorboard .validate_requirements ()
251
306
@@ -275,13 +330,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
275
330
model_channel_name )
276
331
277
332
# Move some of the tensorflow specific init params from hyperparameters into the main init params.
278
- for argument in [ 'checkpoint_path' , 'training_steps' , 'evaluation_steps' ] :
333
+ for argument in ( 'checkpoint_path' , 'training_steps' , 'evaluation_steps' , 'model_dir' ) :
279
334
value = init_params ['hyperparameters' ].pop (argument , None )
280
335
if value is not None :
281
336
init_params [argument ] = value
282
337
283
338
image_name = init_params .pop ('image' )
284
- framework , py_version , tag = framework_name_from_image (image_name )
339
+ framework , py_version , tag = fw . framework_name_from_image (image_name )
285
340
if not framework :
286
341
# If we were unable to parse the framework name from the image it is not one of our
287
342
# officially supported images, in this case just add the image to the init params.
@@ -294,7 +349,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
294
349
# containing framework version, device type and python version (e.g. '1.5-gpu-py2').
295
350
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
296
351
# otherwise extract framework version from the tag itself.
297
- init_params ['framework_version' ] = '1.4' if tag == '1.0' else framework_version_from_tag (
352
+ init_params ['framework_version' ] = '1.4' if tag == '1.0' else fw . framework_version_from_tag (
298
353
tag )
299
354
300
355
training_job_name = init_params ['base_job_name' ]
@@ -328,7 +383,7 @@ def create_model(self, model_server_workers=None, role=None,
328
383
"""
329
384
330
385
role = role or self .role
331
- if endpoint_type == 'tensorflow-serving' :
386
+ if endpoint_type == 'tensorflow-serving' or self . _script_mode_enabled () :
332
387
return self ._create_tfs_model (role = role , vpc_config_override = vpc_config_override )
333
388
334
389
return self ._create_default_model (model_server_workers = model_server_workers , role = role ,
@@ -362,18 +417,39 @@ def hyperparameters(self):
362
417
"""Return hyperparameters used by your custom TensorFlow code during model training."""
363
418
hyperparameters = super (TensorFlow , self ).hyperparameters ()
364
419
365
- if not self .checkpoint_path :
366
- local_code = get_config_value ('local.local_code' , self .sagemaker_session .config )
367
- if self .sagemaker_session .local_mode and local_code :
368
- self .checkpoint_path = '/opt/ml/shared/checkpoints'
369
- else :
370
- self .checkpoint_path = os .path .join (self .output_path ,
371
- self ._current_job_name , 'checkpoints' )
420
+ self .checkpoint_path = self .checkpoint_path or self ._default_s3_path ('checkpoints' )
372
421
373
- additional_hyperparameters = {'checkpoint_path' : self .checkpoint_path ,
374
- 'training_steps' : self .training_steps ,
375
- 'evaluation_steps' : self .evaluation_steps ,
376
- 'sagemaker_requirements' : self .requirements_file }
422
+ if self ._script_mode_enabled ():
423
+ self .model_dir = self .model_dir or self ._default_s3_path ('model' )
424
+ additional_hyperparameters = {'model_dir' : self .model_dir }
425
+ if 'parameter_server' in self .distributions :
426
+ enabled = self .distributions ['parameter_server' ].get ('enabled' , False )
427
+ additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = enabled
428
+ else :
429
+ additional_hyperparameters = {'checkpoint_path' : self .checkpoint_path ,
430
+ 'training_steps' : self .training_steps ,
431
+ 'evaluation_steps' : self .evaluation_steps ,
432
+ 'sagemaker_requirements' : self .requirements_file }
377
433
378
434
hyperparameters .update (Framework ._json_encode_hyperparameters (additional_hyperparameters ))
379
435
return hyperparameters
436
+
437
def _default_s3_path(self, directory):
    """Return the default location for ``directory`` when the caller did not supply one.

    Args:
        directory (str): Name of the final path component (for example ``'checkpoints'``
            or ``'model'``).

    Returns:
        str: ``/opt/ml/shared/<directory>`` when the session is in local mode and the
        ``local.local_code`` config value is truthy; otherwise ``output_path``,
        the current job name and ``directory`` joined with forward slashes.
    """
    # Imported here so this method does not depend on the file's top-level import block.
    import posixpath

    local_code = get_config_value('local.local_code', self.sagemaker_session.config)
    if self.sagemaker_session.local_mode and local_code:
        return '/opt/ml/shared/{}'.format(directory)

    # posixpath.join always uses '/', whereas os.path.join would insert backslashes
    # when the SDK runs on Windows and corrupt the resulting URI.
    return posixpath.join(self.output_path, self._current_job_name, directory)
443
+
444
def _script_mode_enabled(self):
    """Tell whether this estimator runs in script mode.

    Returns:
        bool: True when ``py_version`` is ``'py3'`` or ``script_mode`` is truthy,
        False otherwise.
    """
    # bool() keeps the result a real boolean even if script_mode was given as None or
    # another non-bool value; truthiness for existing callers is unchanged.
    return self.py_version == 'py3' or bool(self.script_mode)
446
+
447
def train_image(self):
    """Return the Docker image to use for training.

    Returns:
        str: ``image_name`` when one was supplied. Otherwise, in script mode, the URI
        built by ``fw.create_image_uri`` for the script mode framework name; in all
        other cases whatever the parent class's ``train_image`` returns.
    """
    # An explicitly provided image always wins.
    if self.image_name:
        return self.image_name

    # Outside script mode the parent class decides which image to use.
    if not self._script_mode_enabled():
        return super(TensorFlow, self).train_image()

    region = self.sagemaker_session.boto_region_name
    return fw.create_image_uri(region, _SCRIPT_MODE, self.train_instance_type,
                               self.framework_version, self.py_version)
0 commit comments