Skip to content

Commit 75028a0

Browse files
committed
Address pr comments
1 parent cfea5ab commit 75028a0

File tree

4 files changed

+39
-45
lines changed

4 files changed

+39
-45
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
185185
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
186186
framework_version (str): TensorFlow version you want to use for executing your model training code.
187187
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
188+
model_dir (str): S3 location where the checkpoint data and models can be exported to during training
189+
(default: None). If not specified a default S3 URI will be generated.
188190
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
189191
relative to ``source_dir``. Details on the format can be found in the
190192
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
@@ -194,6 +196,10 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
194196
Examples:
195197
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
196198
custom-image:latest.
199+
script_mode (bool): If set to True, the estimator will use the Script Mode containers (default: False).
200+
This will be ignored if py_version is set to 'py3'.
201+
distributions (dict): A dictionary with information on how to run distributed training
202+
(default: None).
197203
**kwargs: Additional kwargs passed to the Framework constructor.
198204
"""
199205
if framework_version is None:
@@ -207,7 +213,7 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
207213
self.evaluation_steps = evaluation_steps
208214
self.model_dir = model_dir
209215
self.script_mode = script_mode
210-
self.distributions = distributions
216+
self.distributions = distributions or {}
211217

212218
self._validate_args(py_version=py_version, script_mode=script_mode, framework_version=framework_version,
213219
training_steps=training_steps, evaluation_steps=evaluation_steps,
@@ -283,12 +289,11 @@ def fit_super():
283289
if run_tensorboard_locally and wait is False:
284290
raise ValueError("Tensorboard is not supported with async fit")
285291

286-
if run_tensorboard_locally:
287-
288-
if self.script_mode_enabled():
292+
if self._script_mode_enabled():
293+
if run_tensorboard_locally:
289294
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
290-
return
291-
295+
fit_super()
296+
elif run_tensorboard_locally:
292297
tensorboard = Tensorboard(self)
293298
tensorboard.validate_requirements()
294299

@@ -371,12 +376,9 @@ def create_model(self, model_server_workers=None, role=None,
371376
"""
372377

373378
role = role or self.role
374-
if endpoint_type == 'tensorflow-serving':
379+
if endpoint_type == 'tensorflow-serving' or self._script_mode_enabled():
375380
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)
376381

377-
if self.script_mode_enabled():
378-
raise ValueError(_SCRIPT_MODE_SERVING_ERROR_MSG)
379-
380382
return self._create_default_model(model_server_workers=model_server_workers, role=role,
381383
vpc_config_override=vpc_config_override)
382384

@@ -408,17 +410,14 @@ def hyperparameters(self):
408410
"""Return hyperparameters used by your custom TensorFlow code during model training."""
409411
hyperparameters = super(TensorFlow, self).hyperparameters()
410412

411-
if not self.checkpoint_path:
412-
self.checkpoint_path = self._default_s3_path('checkpoints')
413+
self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints')
413414

414-
if self.script_mode_enabled():
415-
if not self.model_dir:
416-
self.model_dir = self._default_s3_path('model')
415+
if self._script_mode_enabled():
416+
self.model_dir = self.model_dir or self._default_s3_path('model')
417417
additional_hyperparameters = {'model_dir': self.model_dir}
418-
if self.distributions:
419-
if 'parameter_server' in self.distributions:
420-
enabled = self.distributions['parameter_server'].get('enabled', False)
421-
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
418+
if 'parameter_server' in self.distributions:
419+
enabled = self.distributions['parameter_server'].get('enabled', False)
420+
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
422421
else:
423422
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
424423
'training_steps': self.training_steps,
@@ -435,15 +434,15 @@ def _default_s3_path(self, directory):
435434
else:
436435
return os.path.join(self.output_path, self._current_job_name, directory)
437436

438-
def script_mode_enabled(self):
437+
def _script_mode_enabled(self):
439438
return self.py_version == 'py3' or self.script_mode
440439

441440
def train_image(self):
442441
if self.image_name:
443442
return self.image_name
444443

445-
if self.script_mode_enabled():
444+
if self._script_mode_enabled():
446445
return fw.create_image_uri(self.sagemaker_session.boto_region_name, _SCRIPT_MODE,
447446
self.train_instance_type, self.framework_version, self.py_version)
448-
else:
449-
return super(TensorFlow, self).train_image()
447+
448+
return super(TensorFlow, self).train_image()

tests/data/tensorflow_mnist/mnist.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,16 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from __future__ import absolute_import
13+
from __future__ import absolute_import, division, print_function
1414

15-
from __future__ import division
16-
from __future__ import print_function
17-
18-
import numpy as np
19-
import tensorflow as tf
20-
import os
21-
import json
2215
import argparse
23-
from tensorflow.python.platform import tf_logging
16+
import json
2417
import logging as _logging
18+
import numpy as np
19+
import os
2520
import sys as _sys
21+
import tensorflow as tf
22+
from tensorflow.python.platform import tf_logging
2623

2724
tf.logging.set_verbosity(tf.logging.DEBUG)
2825
_handler = _logging.StreamHandler(_sys.stdout)
@@ -137,11 +134,11 @@ def _parse_args():
137134
# hyperparameters sent by the client are passed as command-line arguments to the script.
138135
parser.add_argument('--epochs', type=int, default=1)
139136
# Data, model, and output directories
140-
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
141-
parser.add_argument('--model_dir', type=str, default=os.environ['SM_MODEL_DIR'])
142-
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
143-
parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
144-
parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
137+
parser.add_argument('--output-data-dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
138+
parser.add_argument('--model_dir', type=str)
139+
parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAINING'))
140+
parser.add_argument('--hosts', type=list, default=json.loads(os.environ.get('SM_HOSTS')))
141+
parser.add_argument('--current-host', type=str, default=os.environ.get('SM_CURRENT_HOST'))
145142

146143
return parser.parse_known_args()
147144

tests/integ/test_tf_script_mode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def test_mnist_distributed(sagemaker_session, instance_type):
5656
train_instance_count=2,
5757
train_instance_type=instance_type,
5858
sagemaker_session=sagemaker_session,
59-
py_version='py3',
59+
py_version=integ.PYTHON_VERSION,
60+
script_mode=True,
6061
framework_version='1.11',
6162
distributions=DISTRIBUTION_ENABLED,
6263
base_job_name='test-tf-sm-mnist')

tests/unit/test_tf_estimator.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -690,22 +690,19 @@ def test_script_mode_deprecated_args(sagemaker_session):
690690

691691
def test_script_mode_enabled(sagemaker_session):
692692
tf = _build_tf(sagemaker_session=sagemaker_session, py_version='py3')
693-
assert tf.script_mode_enabled() is True
693+
assert tf._script_mode_enabled() is True
694694

695695
tf = _build_tf(sagemaker_session=sagemaker_session, script_mode=True)
696-
assert tf.script_mode_enabled() is True
696+
assert tf._script_mode_enabled() is True
697697

698698
tf = _build_tf(sagemaker_session=sagemaker_session)
699-
assert tf.script_mode_enabled() is False
699+
assert tf._script_mode_enabled() is False
700700

701701

702702
@patch('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model')
703703
def test_script_mode_create_model(create_tfs_model, sagemaker_session):
704704
tf = _build_tf(sagemaker_session=sagemaker_session, py_version='py3')
705-
with pytest.raises(ValueError) as e:
706-
tf.create_model()
707-
assert tfe._SCRIPT_MODE_SERVING_ERROR_MSG in str(e)
708-
tf.create_model(endpoint_type='tensorflow-serving')
705+
tf.create_model()
709706
create_tfs_model.assert_called_once()
710707

711708

0 commit comments

Comments
 (0)