Skip to content

Support of Horovod and TF 1.12 for TensorFlow Script Mode. TFS 1.12 support #567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 69 commits into from
Jan 17, 2019
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
1eb85ad
Add horovod support
icywang86rui Dec 6, 2018
c63fe06
Add newline at eof
icywang86rui Dec 6, 2018
b91208a
Do not skip integ test
icywang86rui Dec 6, 2018
a1b426a
Edit README to include distributed training with MPI
icywang86rui Dec 6, 2018
10fe7bf
PR commentsw
icywang86rui Dec 7, 2018
c5f68b1
Add processes_per_host and custom_mpi_options
icywang86rui Dec 7, 2018
858079e
Add missing period
icywang86rui Dec 7, 2018
ffc0812
Use distribution in README
icywang86rui Dec 7, 2018
7587e52
Use distributions in README
icywang86rui Dec 7, 2018
f1f8583
Fix README
icywang86rui Dec 7, 2018
64449a5
Imporve documentation
yangaws Dec 17, 2018
c857afe
Address comments from Eric
yangaws Dec 18, 2018
245b75f
Merge remote-tracking branch 'origin/master' into horovod
mvsusp Dec 18, 2018
e3aeb6e
Updated TF version
mvsusp Dec 18, 2018
2bcd290
Fix empty mpi distribution use case
mvsusp Dec 18, 2018
3cfdc57
Add check for necessary files in model.tar.gz
yangaws Dec 19, 2018
561414f
Add benchmarks as submodule
mvsusp Dec 19, 2018
b5d4a1c
Add benchmarks as submodule
mvsusp Dec 19, 2018
7843392
Handle PR comments
mvsusp Dec 19, 2018
1c4e8c5
Update version
mvsusp Dec 19, 2018
41175a2
Handle PR comments
mvsusp Dec 20, 2018
c137be1
Run TF tests against latest container instead of default.
nadiaya Dec 20, 2018
d44d590
Merge branch 'wru-horovod' of github.com:mvsusp/sagemaker-python-sdk …
nadiaya Dec 20, 2018
05ee7c1
Merge branch 'master' into wru-horovod
yangaws Dec 20, 2018
1680073
Fix urllib.parse import errors for python 2.
nadiaya Dec 20, 2018
0232e4a
Merge branch 'wru-horovod' of github.com:mvsusp/sagemaker-python-sdk …
nadiaya Dec 20, 2018
c19021a
Fix horovod integ test tar file extract error
yangaws Dec 20, 2018
6722c07
Merge branch 'master' into wru-horovod
yangaws Dec 20, 2018
24a5d61
fix flake8
yangaws Dec 20, 2018
160e646
Removed unnecessary tests
mvsusp Dec 21, 2018
7f93812
Merge branch 'master' into wru-horovod
uditbhatia Jan 10, 2019
333ebf7
Removing duplicated/unused TF import
uditbhatia Jan 10, 2019
1f80caf
Merge branch 'master' into wru-horovod
uditbhatia Jan 10, 2019
a1ec1b4
Add horovod support
icywang86rui Dec 6, 2018
9e8d88a
Add newline at eof
icywang86rui Dec 6, 2018
fafc9bb
Do not skip integ test
icywang86rui Dec 6, 2018
df313d8
Edit README to include distributed training with MPI
icywang86rui Dec 6, 2018
3fd1bf0
PR commentsw
icywang86rui Dec 7, 2018
e3051da
Add processes_per_host and custom_mpi_options
icywang86rui Dec 7, 2018
ead6229
Add missing period
icywang86rui Dec 7, 2018
d41e163
Use distribution in README
icywang86rui Dec 7, 2018
2aff9fc
Use distributions in README
icywang86rui Dec 7, 2018
3915406
Fix README
icywang86rui Dec 7, 2018
a07c0d6
Imporve documentation
yangaws Dec 17, 2018
308a31c
Address comments from Eric
yangaws Dec 18, 2018
56d6d07
Updated TF version
mvsusp Dec 18, 2018
3145ffd
Fix empty mpi distribution use case
mvsusp Dec 18, 2018
dd838ef
Add check for necessary files in model.tar.gz
yangaws Dec 19, 2018
15bfe00
Add benchmarks as submodule
mvsusp Dec 19, 2018
8e9734e
Add benchmarks as submodule
mvsusp Dec 19, 2018
b22671d
Handle PR comments
mvsusp Dec 19, 2018
20e906e
Update version
mvsusp Dec 19, 2018
430cd0a
Handle PR comments
mvsusp Dec 20, 2018
bd9c92d
Run TF tests against latest container instead of default.
nadiaya Dec 20, 2018
2fcdaea
Fix urllib.parse import errors for python 2.
nadiaya Dec 20, 2018
3d06e11
Fix horovod integ test tar file extract error
yangaws Dec 20, 2018
c78eb31
fix flake8
yangaws Dec 20, 2018
abce1dd
Removed unnecessary tests
mvsusp Dec 21, 2018
3342e94
Removing duplicated/unused TF import
uditbhatia Jan 10, 2019
cb7610f
Capitalizing the mpi_distribution ps_distribution constant
uditbhatia Jan 10, 2019
dca2173
resolving conflists
uditbhatia Jan 10, 2019
f91b29c
Merge branch 'master' into wru-horovod
uditbhatia Jan 11, 2019
30995dd
Restoring version default to 1.12
uditbhatia Jan 11, 2019
7c93fdc
Accomodating the mvs pr comments
uditbhatia Jan 11, 2019
a8d2cb0
Updating changelog
uditbhatia Jan 11, 2019
55c1998
chaing the TF_VERSION field to 1.11 from 1.12 in defaults.py
uditbhatia Jan 11, 2019
c913c40
Merge branch 'master' into wru-horovod
uditbhatia Jan 16, 2019
23d8074
Fixing flake 8 errors after merge from master and updating changelog
uditbhatia Jan 16, 2019
177f37f
Bumping up the python SDK version to 1.17.3 (as per instructions in M…
uditbhatia Jan 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
CHANGELOG
=========

1.16.4.dev
==========
1.17.0
======

* feature: support for Tensorflow 1.12
* feature: support for Tensorflow Serving 1.12
* feature: support for Horovod
* bug-fix: Session: don't allow get_execution_role() to return an ARN that's not a role but has "role" in the name

1.16.3
Expand Down
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __getattr__(cls, name):
'numpy', 'scipy', 'scipy.sparse']
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

version = '1.16.3'
version = '1.17.0'
project = u'sagemaker'

# Add any Sphinx extension module names here, as strings. They can be extensions
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@
from sagemaker.session import s3_input # noqa: F401
from sagemaker.session import get_execution_role # noqa: F401

__version__ = '1.16.3'
__version__ = '1.17.0'
3 changes: 3 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,9 @@ class Framework(EstimatorBase):

__framework_name__ = None
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'

def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
Expand Down
69 changes: 63 additions & 6 deletions src/sagemaker/tensorflow/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,64 @@ After attaching, the estimator can be deployed as usual.

tf_estimator = TensorFlow.attach(training_job_name=training_job_name)

Distributed Training
''''''''''''''''''''

To run your training job with multiple instances in a distributed fashion, set ``train_instance_count``
to a number larger than 1. We support two different types of distributed training, parameter server and Horovod.
The ``distributions`` parameter is used to configure which distributed training strategy to use.

Training with parameter servers
"""""""""""""""""""""""""""""""

If you specify parameter_server as the value of the distributions parameter, the container launches a parameter server
thread on each instance in the training cluster, and then executes your training code. You can find more information on
TensorFlow distributed training at `TensorFlow docs <https://www.tensorflow.org/deploy/distributed>`__.
To enable parameter server training:

.. code:: python

from sagemaker.tensorflow import TensorFlow

tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
train_instance_count=2, train_instance_type='ml.p2.xlarge',
framework_version='1.11', py_version='py3',
distributions={'parameter_server': {'enabled': True}})
tf_estimator.fit('s3://bucket/path/to/training/data')

Training with Horovod
"""""""""""""""""""""

Horovod is a distributed training framework based on MPI. You can find more details at `Horovod README <https://github.com/uber/horovod>`__.

The container sets up the MPI environment and executes the ``mpirun`` command enabling you to run any Horovod
training script with Script Mode.

Training with ``MPI`` is configured by specifying following fields in ``distributions``:

- ``enabled (bool)``: If set to ``True``, the MPI setup is performed and ``mpirun`` command is executed.
- ``processes_per_host (int)``: Number of processes MPI should launch on each host. Note, this should not be
greater than the available slots on the selected instance type.
- ``custom_mpi_options (str)``: Additional command line arguments to pass to ``mpirun``.

In the below example we create an estimator to launch Horovod distributed training with 2 processes on one host:

.. code:: python

from sagemaker.tensorflow import TensorFlow

tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
train_instance_count=1, train_instance_type='ml.p2.xlarge',
framework_version='1.11', py_version='py3',
distributions={
'mpi': {
'enabled': True,
'processes_per_host': 2,
'custom_mpi_options': '--NCCL_DEBUG INFO'
}
})
tf_estimator.fit('s3://bucket/path/to/training/data')

sagemaker.tensorflow.TensorFlow class
'''''''''''''''''''''''''''''''''''''

Expand Down Expand Up @@ -277,11 +335,10 @@ Optional:
- ``model_dir (str)`` Location where model data, checkpoint data, and TensorBoard checkpoints should be saved during training.
If not specified a S3 location will be generated under the training job's default bucket. And ``model_dir`` will be
passed in your training script as one of the command line arguments.
- ``distributions (dict)`` Configure your distrubtion strategy with this argument. For launching parameter server for
for distributed training, you must set ``distributions`` to ``{'parameter_server': {'enabled': True}}``
- ``distributions (dict)`` Configure your distribution strategy with this argument.

Training with Pipe Mode using PipeModeDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Amazon SageMaker allows users to create training jobs using Pipe input mode.
With Pipe input mode, your dataset is streamed directly to your training instances instead of being downloaded first.
Expand Down Expand Up @@ -327,9 +384,9 @@ To run training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your
from sagemaker.tensorflow import TensorFlow

tf_estimator = TensorFlow(entry_point='tf-train-with-pipemodedataset.py', role='SageMakerRole',
training_steps=10000, evaluation_steps=100,
train_instance_count=1, train_instance_type='ml.p2.xlarge',
framework_version='1.10.0', input_mode='Pipe')
training_steps=10000, evaluation_steps=100,
train_instance_count=1, train_instance_type='ml.p2.xlarge',
framework_version='1.10.0', input_mode='Pipe')

tf_estimator.fit('s3://bucket/path/to/training/data')

Expand Down
35 changes: 28 additions & 7 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class TensorFlow(Framework):
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""

__framework_name__ = 'tensorflow'
LATEST_VERSION = '1.12'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
framework_version=None, model_dir=None, requirements_file='', image_name=None,
Expand Down Expand Up @@ -200,14 +201,21 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
script_mode (bool): If set to True will the estimator will use the Script Mode containers (default: False).
This will be ignored if py_version is set to 'py3'.
distributions (dict): A dictionary with information on how to run distributed training
(default: None). Currently we only support distributed training with parameter servers. To enable it
use the following setup:
(default: None). Currently we support distributed training with parameter servers and MPI. To enable
parameter servers use the following setup:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow up:
we should have a test that checks parmater server + mpi works.

{
'parameter_server':
{
'enabled': True
}
}
To enable MPI:
{
'mpi':
{
'enabled': True
}
}
**kwargs: Additional kwargs passed to the Framework constructor.
"""
if framework_version is None:
Expand Down Expand Up @@ -419,13 +427,24 @@ def hyperparameters(self):
hyperparameters = super(TensorFlow, self).hyperparameters()

self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints')
mpi_enabled = False

if self._script_mode_enabled():
self.model_dir = self.model_dir or self._default_s3_path('model')
additional_hyperparameters = {'model_dir': self.model_dir}
additional_hyperparameters = {}

if 'parameter_server' in self.distributions:
enabled = self.distributions['parameter_server'].get('enabled', False)
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
ps_enabled = self.distributions['parameter_server'].get('enabled', False)
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled

if 'mpi' in self.distributions:
mpi_dict = self.distributions['mpi']
mpi_enabled = mpi_dict.get('enabled', False)
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get('processes_per_host', 1)
additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get('custom_mpi_options', '')

self.model_dir = self.model_dir or self._default_s3_path('model', mpi=mpi_enabled)
additional_hyperparameters['model_dir'] = self.model_dir
else:
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
'training_steps': self.training_steps,
Expand All @@ -435,10 +454,12 @@ def hyperparameters(self):
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
return hyperparameters

def _default_s3_path(self, directory):
def _default_s3_path(self, directory, mpi=False):
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
if self.sagemaker_session.local_mode and local_code:
return '/opt/ml/shared/{}'.format(directory)
elif mpi:
return '/opt/ml/model'
else:
return os.path.join(self.output_path, self._current_job_name, directory)

Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from sagemaker.pytorch import PyTorch
from sagemaker.rl import RLEstimator
from sagemaker.sklearn.defaults import SKLEARN_VERSION
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow import TensorFlow

DEFAULT_REGION = 'us-west-2'

Expand All @@ -43,7 +43,7 @@ def pytest_addoption(parser):
parser.addoption('--rl-ray-full-version', action='store',
default=RLEstimator.RAY_LATEST_VERSION)
parser.addoption('--sklearn-full-version', action='store', default=SKLEARN_VERSION)
parser.addoption('--tf-full-version', action='store', default=TF_VERSION)
parser.addoption('--tf-full-version', action='store', default=TensorFlow.LATEST_VERSION)


def pytest_configure(config):
Expand Down Expand Up @@ -126,7 +126,7 @@ def sklearn_version(request):

@pytest.fixture(scope='module', params=['1.4', '1.4.1', '1.5', '1.5.0', '1.6', '1.6.0',
'1.7', '1.7.0', '1.8', '1.8.0', '1.9', '1.9.0',
'1.10', '1.10.0', '1.11', '1.11.0'])
'1.10', '1.10.0', '1.11', '1.11.0', '1.12', '1.12.0'])
def tf_version(request):
return request.param

Expand Down
3 changes: 3 additions & 0 deletions tests/data/horovod/launcher.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/usr/bin/env bash

python benchmarks/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py --num_batches=500 --model vgg16 --variable_update horovod --horovod_device gpu --use_fp16 --summary_verbosity 1 --save_summaries_steps 10 --train_dir /opt/ml/model --eval_dir /opt/ml/model --batch_size 32
11 changes: 11 additions & 0 deletions tests/data/horovod/test_hvd_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import json
import os
import horovod.tensorflow as hvd

hvd.init()

with open(os.path.join('/opt/ml/model/rank-%s' % hvd.rank()), 'w+') as f:
basic_info = {'rank': hvd.rank(), 'size': hvd.size()}

print(basic_info)
json.dump(basic_info, f)
135 changes: 135 additions & 0 deletions tests/data/tensorflow_mnist/horovod_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import argparse
import os

import tensorflow as tf
import horovod.tensorflow as hvd

layers = tf.contrib.layers
learn = tf.contrib.learn

tf.logging.set_verbosity(tf.logging.INFO)


def _parse_args():
parser = argparse.ArgumentParser()
# Data, model, and output directories
parser.add_argument('--output-data-dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
parser.add_argument('--model_dir', type=str)

return parser.parse_known_args()


def conv_model(feature, target, mode):
"""2-layer convolution model."""
# Convert the target to a one-hot tensor of shape (batch_size, 10) and
# with a on-value of 1 for each one-hot vector of length 10.
target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)

# Reshape feature to 4d tensor with 2nd and 3rd dimensions being
# image width and height final dimension being the number of color channels.
feature = tf.reshape(feature, [-1, 28, 28, 1])

# First conv layer will compute 32 features for each 5x5 patch
with tf.variable_scope('conv_layer1'):
h_conv1 = layers.conv2d(
feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
h_pool1 = tf.nn.max_pool(
h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# Second conv layer will compute 64 features for each 5x5 patch.
with tf.variable_scope('conv_layer2'):
h_conv2 = layers.conv2d(
h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
h_pool2 = tf.nn.max_pool(
h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# reshape tensor into a batch of vectors
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])

# Densely connected layer with 1024 neurons.
h_fc1 = layers.dropout(
layers.fully_connected(
h_pool2_flat, 1024, activation_fn=tf.nn.relu),
keep_prob=0.5,
is_training=mode == tf.contrib.learn.ModeKeys.TRAIN)

# Compute logits (1 per class) and compute loss.
logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
loss = tf.losses.softmax_cross_entropy(target, logits)

return tf.argmax(logits, 1), loss


def main(_):
args, unknown = _parse_args()

# Horovod: initialize Horovod.
hvd.init()

# Download and load MNIST dataset.
mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())

# Build model...
with tf.name_scope('input'):
image = tf.placeholder(tf.float32, [None, 784], name='image')
label = tf.placeholder(tf.float32, [None], name='label')
predict, loss = conv_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())

# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)

global_step = tf.contrib.framework.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)

hooks = [
# Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states
# from rank 0 to all other processes. This is necessary to ensure consistent
# initialization of all workers when training is started with random weights
# or restored from a checkpoint.
hvd.BroadcastGlobalVariablesHook(0),

tf.train.StopAtStepHook(last_step=200 // hvd.size()),

tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
every_n_iter=10),
]

# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting them.
checkpoint_dir = os.path.join(args.model_dir, 'checkpoints') if hvd.rank() == 0 else None

# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
hooks=hooks,
config=config) as mon_sess:
while not mon_sess.should_stop():
# Run a training step synchronously.
image_, label_ = mnist.train.next_batch(100)
mon_sess.run(train_op, feed_dict={image: image_, label: label_})


if __name__ == "__main__":
tf.app.run()
Loading