@@ -185,6 +185,8 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
185
185
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
186
186
framework_version (str): TensorFlow version you want to use for executing your model training code.
187
187
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
188
+ model_dir (str): S3 location where the checkpoint data and models can be exported to during training
189
+ (default: None). If not specified a default S3 URI will be generated.
188
190
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
189
191
relative to ``source_dir``. Details on the format can be found in the
190
192
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
@@ -194,6 +196,10 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
194
196
Examples:
195
197
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
196
198
custom-image:latest.
199
+ script_mode (bool): If set to True, the estimator will use the Script Mode containers (default: False).
200
+ This value is ignored if py_version is set to 'py3', because Script Mode is always used with Python 3.
201
+ distributions (dict): A dictionary with information on how to run distributed training
202
+ (default: None).
197
203
**kwargs: Additional kwargs passed to the Framework constructor.
198
204
"""
199
205
if framework_version is None :
@@ -207,7 +213,7 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
207
213
self .evaluation_steps = evaluation_steps
208
214
self .model_dir = model_dir
209
215
self .script_mode = script_mode
210
- self .distributions = distributions
216
+ self .distributions = distributions or {}
211
217
212
218
self ._validate_args (py_version = py_version , script_mode = script_mode , framework_version = framework_version ,
213
219
training_steps = training_steps , evaluation_steps = evaluation_steps ,
@@ -283,12 +289,11 @@ def fit_super():
283
289
if run_tensorboard_locally and wait is False :
284
290
raise ValueError ("Tensorboard is not supported with async fit" )
285
291
286
- if run_tensorboard_locally :
287
-
288
- if self .script_mode_enabled ():
292
+ if self ._script_mode_enabled ():
293
+ if run_tensorboard_locally :
289
294
LOGGER .warning (_SCRIPT_MODE_TENSORBOARD_WARNING .format (self .model_dir ))
290
- return
291
-
295
+ fit_super ()
296
+ elif run_tensorboard_locally :
292
297
tensorboard = Tensorboard (self )
293
298
tensorboard .validate_requirements ()
294
299
@@ -371,12 +376,9 @@ def create_model(self, model_server_workers=None, role=None,
371
376
"""
372
377
373
378
role = role or self .role
374
- if endpoint_type == 'tensorflow-serving' :
379
+ if endpoint_type == 'tensorflow-serving' or self . _script_mode_enabled () :
375
380
return self ._create_tfs_model (role = role , vpc_config_override = vpc_config_override )
376
381
377
- if self .script_mode_enabled ():
378
- raise ValueError (_SCRIPT_MODE_SERVING_ERROR_MSG )
379
-
380
382
return self ._create_default_model (model_server_workers = model_server_workers , role = role ,
381
383
vpc_config_override = vpc_config_override )
382
384
@@ -408,17 +410,14 @@ def hyperparameters(self):
408
410
"""Return hyperparameters used by your custom TensorFlow code during model training."""
409
411
hyperparameters = super (TensorFlow , self ).hyperparameters ()
410
412
411
- if not self .checkpoint_path :
412
- self .checkpoint_path = self ._default_s3_path ('checkpoints' )
413
+ self .checkpoint_path = self .checkpoint_path or self ._default_s3_path ('checkpoints' )
413
414
414
- if self .script_mode_enabled ():
415
- if not self .model_dir :
416
- self .model_dir = self ._default_s3_path ('model' )
415
+ if self ._script_mode_enabled ():
416
+ self .model_dir = self .model_dir or self ._default_s3_path ('model' )
417
417
additional_hyperparameters = {'model_dir' : self .model_dir }
418
- if self .distributions :
419
- if 'parameter_server' in self .distributions :
420
- enabled = self .distributions ['parameter_server' ].get ('enabled' , False )
421
- additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = enabled
418
+ if 'parameter_server' in self .distributions :
419
+ enabled = self .distributions ['parameter_server' ].get ('enabled' , False )
420
+ additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = enabled
422
421
else :
423
422
additional_hyperparameters = {'checkpoint_path' : self .checkpoint_path ,
424
423
'training_steps' : self .training_steps ,
@@ -435,15 +434,15 @@ def _default_s3_path(self, directory):
435
434
else :
436
435
return os .path .join (self .output_path , self ._current_job_name , directory )
437
436
438
- def script_mode_enabled (self ):
437
+ def _script_mode_enabled (self ):
439
438
return self .py_version == 'py3' or self .script_mode
440
439
441
440
def train_image (self ):
442
441
if self .image_name :
443
442
return self .image_name
444
443
445
- if self .script_mode_enabled ():
444
+ if self ._script_mode_enabled ():
446
445
return fw .create_image_uri (self .sagemaker_session .boto_region_name , _SCRIPT_MODE ,
447
446
self .train_instance_type , self .framework_version , self .py_version )
448
- else :
449
- return super (TensorFlow , self ).train_image ()
447
+
448
+ return super (TensorFlow , self ).train_image ()
0 commit comments