Skip to content

Commit 835d1af

Browse files
authored
Add support for TensorFlow script mode and Python 3 (#475)
* Add script_mode flag to TensorFlow estimator * Add model_dir and distributions to tf estimator * Validate args for script mode * Add unit tests * Add integ tests
1 parent faccfb2 commit 835d1af

File tree

14 files changed

+536
-56
lines changed

14 files changed

+536
-56
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ CHANGELOG
2424
* build: added pylint
2525
* build: upgrade docker-compose to 1.23
2626
* enhancement: Frameworks: update warning for not setting framework_version as we aren't planning a breaking change anymore
27+
* feature: Estimator: add script mode and Python 3 support for TensorFlow
2728
* enhancement: Session: remove hardcoded 'training' from job status error message
2829
* bug-fix: Updated Cloudwatch namespace for metrics in TrainingJobsAnalytics
2930
* bug-fix: Changes to use correct s3 bucket and time range for dataframes in TrainingJobAnalytics.

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,7 @@ class Framework(EstimatorBase):
610610
"""
611611

612612
__framework_name__ = None
613+
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
613614

614615
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
615616
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
'If you would like to use version {latest}, ' \
3232
'please add framework_version={latest} to your constructor.'
3333

34+
EMPTY_FRAMEWORK_VERSION_ERROR = 'framework_version is required for script mode estimator. ' \
35+
'Please add framework_version={} to your constructor to avoid this error.'
36+
3437
VALID_PY_VERSIONS = ['py2', 'py3']
3538

3639

src/sagemaker/mxnet/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class MXNet(Framework):
3030
__framework_name__ = 'mxnet'
3131

3232
_LOWEST_SCRIPT_MODE_VERSION = ['1', '3']
33-
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
3433
LATEST_VERSION = '1.3'
3534

3635
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',

src/sagemaker/tensorflow/estimator.py

Lines changed: 98 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
import time
2323

2424
from sagemaker.estimator import Framework
25-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \
26-
empty_framework_version_warning
25+
import sagemaker.fw_utils as fw
2726
from sagemaker.tensorflow.defaults import TF_VERSION
2827
from sagemaker.tensorflow.model import TensorFlowModel
2928
from sagemaker.tensorflow.serving import Model
@@ -34,6 +33,16 @@
3433
LOGGER = logging.getLogger('sagemaker')
3534

3635

36+
_FRAMEWORK_MODE_ARGS = ('training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path')
37+
_SCRIPT_MODE = 'tensorflow-scriptmode'
38+
_SCRIPT_MODE_SERVING_ERROR_MSG = 'Script mode containers does not support serving yet. ' \
39+
'Please use our new tensorflow-serving container by creating the model ' \
40+
'with \'endpoint_type\' set to \'tensorflow-serving\'.'
41+
_SCRIPT_MODE_TENSORBOARD_WARNING = 'Tensorboard is not supported with script mode. You can run the following ' \
42+
'command: tensorboard --logdir {} --host localhost --port 6006 This can be ' \
43+
'run from anywhere with access to the S3 URI used as the logdir.'
44+
45+
3746
class Tensorboard(threading.Thread):
3847
def __init__(self, estimator, logdir=None):
3948
"""Initialize ``Tensorboard`` instance.
@@ -163,9 +172,9 @@ class TensorFlow(Framework):
163172

164173
__framework_name__ = 'tensorflow'
165174

166-
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
167-
py_version='py2', framework_version=None, requirements_file='', image_name=None,
168-
**kwargs):
175+
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
176+
framework_version=None, model_dir=None, requirements_file='', image_name=None,
177+
script_mode=False, distributions=None, **kwargs):
169178
"""Initialize an ``TensorFlow`` estimator.
170179
Args:
171180
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
@@ -176,6 +185,9 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
176185
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
177186
framework_version (str): TensorFlow version you want to use for executing your model training code.
178187
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
188+
model_dir (str): S3 location where the checkpoint data and models can be exported to during training
189+
(default: None). If not specified a default S3 URI will be generated. It will be passed in the
190+
training script as one of the command line arguments.
179191
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
180192
relative to ``source_dir``. Details on the format can be found in the
181193
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
@@ -185,21 +197,61 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
185197
Examples:
186198
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
187199
custom-image:latest.
200+
script_mode (bool): If set to True will the estimator will use the Script Mode containers (default: False).
201+
This will be ignored if py_version is set to 'py3'.
202+
distribution (dict): A dictionary with information on how to run distributed training
203+
(default: None). Currently we only support distributed training with parameter servers. To enable it
204+
use the following setup:
205+
{
206+
'parameter_server':
207+
{
208+
'enabled': True
209+
}
210+
}
188211
**kwargs: Additional kwargs passed to the Framework constructor.
189212
"""
190213
if framework_version is None:
191-
LOGGER.warning(empty_framework_version_warning(TF_VERSION, TF_VERSION))
214+
LOGGER.warning(fw.empty_framework_version_warning(TF_VERSION, TF_VERSION))
192215
self.framework_version = framework_version or TF_VERSION
193216

194217
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
195218
self.checkpoint_path = checkpoint_path
196219
self.py_version = py_version
197220
self.training_steps = training_steps
198221
self.evaluation_steps = evaluation_steps
222+
self.model_dir = model_dir
223+
self.script_mode = script_mode
224+
self.distributions = distributions or {}
199225

226+
self._validate_args(py_version=py_version, script_mode=script_mode, framework_version=framework_version,
227+
training_steps=training_steps, evaluation_steps=evaluation_steps,
228+
requirements_file=requirements_file, checkpoint_path=checkpoint_path)
200229
self._validate_requirements_file(requirements_file)
201230
self.requirements_file = requirements_file
202231

232+
def _validate_args(self, py_version, script_mode, framework_version, training_steps,
233+
evaluation_steps, requirements_file, checkpoint_path):
234+
235+
if py_version == 'py3' or script_mode:
236+
237+
if framework_version is None:
238+
raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)
239+
240+
found_args = []
241+
if training_steps:
242+
found_args.append('training_steps')
243+
if evaluation_steps:
244+
found_args.append('evaluation_steps')
245+
if requirements_file:
246+
found_args.append('requirements_file')
247+
if checkpoint_path:
248+
found_args.append('checkpoint_path')
249+
if found_args:
250+
raise AttributeError(
251+
'{} are deprecated in script mode. Please do not set {}.'
252+
.format(', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(found_args))
253+
)
254+
203255
def _validate_requirements_file(self, requirements_file):
204256
if not requirements_file:
205257
return
@@ -245,7 +297,10 @@ def fit_super():
245297
if run_tensorboard_locally and wait is False:
246298
raise ValueError("Tensorboard is not supported with async fit")
247299

248-
if run_tensorboard_locally:
300+
if self._script_mode_enabled() and run_tensorboard_locally:
301+
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
302+
fit_super()
303+
elif run_tensorboard_locally:
249304
tensorboard = Tensorboard(self)
250305
tensorboard.validate_requirements()
251306

@@ -275,13 +330,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
275330
model_channel_name)
276331

277332
# Move some of the tensorflow specific init params from hyperparameters into the main init params.
278-
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
333+
for argument in ('checkpoint_path', 'training_steps', 'evaluation_steps', 'model_dir'):
279334
value = init_params['hyperparameters'].pop(argument, None)
280335
if value is not None:
281336
init_params[argument] = value
282337

283338
image_name = init_params.pop('image')
284-
framework, py_version, tag = framework_name_from_image(image_name)
339+
framework, py_version, tag = fw.framework_name_from_image(image_name)
285340
if not framework:
286341
# If we were unable to parse the framework name from the image it is not one of our
287342
# officially supported images, in this case just add the image to the init params.
@@ -294,7 +349,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
294349
# containing framework version, device type and python version (e.g. '1.5-gpu-py2').
295350
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
296351
# otherwise extract framework version from the tag itself.
297-
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(
352+
init_params['framework_version'] = '1.4' if tag == '1.0' else fw.framework_version_from_tag(
298353
tag)
299354

300355
training_job_name = init_params['base_job_name']
@@ -328,7 +383,7 @@ def create_model(self, model_server_workers=None, role=None,
328383
"""
329384

330385
role = role or self.role
331-
if endpoint_type == 'tensorflow-serving':
386+
if endpoint_type == 'tensorflow-serving' or self._script_mode_enabled():
332387
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)
333388

334389
return self._create_default_model(model_server_workers=model_server_workers, role=role,
@@ -362,18 +417,39 @@ def hyperparameters(self):
362417
"""Return hyperparameters used by your custom TensorFlow code during model training."""
363418
hyperparameters = super(TensorFlow, self).hyperparameters()
364419

365-
if not self.checkpoint_path:
366-
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
367-
if self.sagemaker_session.local_mode and local_code:
368-
self.checkpoint_path = '/opt/ml/shared/checkpoints'
369-
else:
370-
self.checkpoint_path = os.path.join(self.output_path,
371-
self._current_job_name, 'checkpoints')
420+
self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints')
372421

373-
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
374-
'training_steps': self.training_steps,
375-
'evaluation_steps': self.evaluation_steps,
376-
'sagemaker_requirements': self.requirements_file}
422+
if self._script_mode_enabled():
423+
self.model_dir = self.model_dir or self._default_s3_path('model')
424+
additional_hyperparameters = {'model_dir': self.model_dir}
425+
if 'parameter_server' in self.distributions:
426+
enabled = self.distributions['parameter_server'].get('enabled', False)
427+
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
428+
else:
429+
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
430+
'training_steps': self.training_steps,
431+
'evaluation_steps': self.evaluation_steps,
432+
'sagemaker_requirements': self.requirements_file}
377433

378434
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
379435
return hyperparameters
436+
437+
def _default_s3_path(self, directory):
438+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
439+
if self.sagemaker_session.local_mode and local_code:
440+
return '/opt/ml/shared/{}'.format(directory)
441+
else:
442+
return os.path.join(self.output_path, self._current_job_name, directory)
443+
444+
def _script_mode_enabled(self):
445+
return self.py_version == 'py3' or self.script_mode
446+
447+
def train_image(self):
448+
if self.image_name:
449+
return self.image_name
450+
451+
if self._script_mode_enabled():
452+
return fw.create_image_uri(self.sagemaker_session.boto_region_name, _SCRIPT_MODE,
453+
self.train_instance_type, self.framework_version, self.py_version)
454+
455+
return super(TensorFlow, self).train_image()
2.99 MB
Binary file not shown.
4.03 KB
Binary file not shown.
2.99 MB
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)