Skip to content

Commit 8c7d644

Browse files
committed
More pr comments
1 parent 75028a0 commit 8c7d644

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
186186
framework_version (str): TensorFlow version you want to use for executing your model training code.
187187
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
188188
model_dir (str): S3 location where the checkpoint data and models can be exported to during training
189-
(default: None). If not specified a default S3 URI will be generated.
189+
(default: None). If not specified a default S3 URI will be generated. It will be passed in the
190+
training script as one of the command line arguments.
190191
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
191192
relative to ``source_dir``. Details on the format can be found in the
192193
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
@@ -199,7 +200,14 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
199200
script_mode (bool): If set to True will the estimator will use the Script Mode containers (default: False).
200201
This will be ignored if py_version is set to 'py3'.
201202
distribution (dict): A dictionary with information on how to run distributed training
202-
(default: None).
203+
(default: None). Currently we only support distributed training with parameter servers. To enable it
204+
use the following setup:
205+
{
206+
'parameter_server':
207+
{
208+
'enabled': True
209+
}
210+
}
203211
**kwargs: Additional kwargs passed to the Framework constructor.
204212
"""
205213
if framework_version is None:
@@ -289,9 +297,8 @@ def fit_super():
289297
if run_tensorboard_locally and wait is False:
290298
raise ValueError("Tensorboard is not supported with async fit")
291299

292-
if self._script_mode_enabled():
293-
if run_tensorboard_locally:
294-
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
300+
if self._script_mode_enabled() and run_tensorboard_locally:
301+
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
295302
fit_super()
296303
elif run_tensorboard_locally:
297304
tensorboard = Tensorboard(self)

0 commit comments

Comments
 (0)