[DO NOT MERGE] Enable distributed training with Horovod for TensorFlow Script Mode #529

Status: Closed. Wants to merge 16 commits.

Changes from all commits
3 changes: 3 additions & 0 deletions src/sagemaker/estimator.py

@@ -734,6 +734,9 @@ class Framework(EstimatorBase):

    __framework_name__ = None
    LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
    LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
    MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
    MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'

    def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
                 container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
69 changes: 63 additions & 6 deletions src/sagemaker/tensorflow/README.rst

@@ -207,6 +207,64 @@ After attaching, the estimator can be deployed as usual.

    tf_estimator = TensorFlow.attach(training_job_name=training_job_name)

Distributed Training
''''''''''''''''''''

To run your training job with multiple instances in a distributed fashion, set ``train_instance_count``
to a number larger than 1. We support two different types of distributed training: parameter server and Horovod.
Contributor: is running a script that uses MPI but not Horovod a use case?

Contributor (author): I think it is, but if you use the TensorFlow container it uses Horovod. I could be wrong.

Contributor: Yes, MPI without Horovod is a valid use case.

Contributor: If that's the case, then I think this should be changed to something like "We support two different ways of handling distributed training: parameter servers and MPI. The use of MPI can be with or without Horovod." Maybe include a link to the Horovod documentation as well.

The ``distributions`` parameter is used to configure which distributed training strategy to use.
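
At a glance, the two shapes ``distributions`` accepts are shown below; both are taken directly from the examples that follow in this section:

.. code:: python

    # Launch parameter server threads on each instance:
    distributions = {'parameter_server': {'enabled': True}}

    # Set up MPI (used for Horovod) and launch the script with mpirun:
    distributions = {'mpi': {'enabled': True}}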

Training with parameter servers
"""""""""""""""""""""""""""""""

If you specify parameter_server as the value of the distributions parameter, the container launches a parameter server
Contributor: backticks around parameter_server
thread on each instance in the training cluster, and then executes your training code. You can find more information on
TensorFlow distributed training at `TensorFlow docs <https://www.tensorflow.org/deploy/distributed>`__.
To enable parameter server training:

.. code:: python

    from sagemaker.tensorflow import TensorFlow

    tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
                              train_instance_count=2, train_instance_type='ml.p2.xlarge',
                              framework_version='1.11', py_version='py3',
                              distributions={'parameter_server': {'enabled': True}})
    tf_estimator.fit('s3://bucket/path/to/training/data')
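
When parameter server training is enabled this way, the flag is passed to your training container as a framework
hyperparameter; the key name comes from the ``LAUNCH_PS_ENV_NAME`` constant on ``Framework`` in
``src/sagemaker/estimator.py`` (shown in the diff above):

.. code:: python

    # What the estimator serializes into the job's hyperparameters:
    {'sagemaker_parameter_server_enabled': True}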

Training with Horovod
"""""""""""""""""""""

Horovod is a distributed training framework based on MPI. You can find more details at `Horovod README <https://github.com/uber/horovod>`__.

The container sets up the MPI environment and executes the ``mpirun`` command, enabling you to run any Horovod
training script with Script Mode.
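
To make that concrete, the sketch below approximates the kind of command the container assembles. It is
illustrative only: the host names, the flag set, and the script path are assumptions for the example, not the
container's actual implementation.

.. code:: python

    # Hedged sketch of an mpirun launch; assumes two hosts named 'algo-1' and
    # 'algo-2', each running 2 processes, plus a user-supplied options string.
    processes_per_host = 2
    hosts = ['algo-1', 'algo-2']
    custom_mpi_options = '-x NCCL_DEBUG=INFO'

    mpirun_command = ' '.join(
        ['mpirun',
         '-np', str(len(hosts) * processes_per_host),  # total number of MPI processes
         '-H', ','.join('%s:%d' % (host, processes_per_host) for host in hosts),  # slots per host
         custom_mpi_options,  # user-supplied extras appended verbatim
         'python', 'tf-train.py'])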

Training with ``MPI`` is configured by specifying the following fields in ``distributions``:

- ``enabled (bool)``: If set to ``True``, the MPI setup is performed and the ``mpirun`` command is executed.
- ``processes_per_host (int)``: Number of processes MPI should launch on each host. Note that this should not be
  greater than the number of available slots on the selected instance type.
- ``custom_mpi_options (str)``: Additional command line arguments to pass to ``mpirun``.

In the example below, we create an estimator to launch Horovod distributed training with 2 processes on one host:

.. code:: python

    from sagemaker.tensorflow import TensorFlow

    tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
                              train_instance_count=1, train_instance_type='ml.p2.xlarge',
                              framework_version='1.11', py_version='py3',
                              distributions={
                                  'mpi': {
                                      'enabled': True,
                                      'processes_per_host': 2,
                                      'custom_mpi_options': '-x NCCL_DEBUG=INFO'
                                  }
                              })
    tf_estimator.fit('s3://bucket/path/to/training/data')
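
For reference, the ``mpi`` settings above are serialized into framework hyperparameters before the job starts;
the key names come from the constants this PR adds to ``Framework`` in ``src/sagemaker/estimator.py``:

.. code:: python

    # Hyperparameters derived from the distributions dict above.
    {
        'sagemaker_mpi_enabled': True,
        'sagemaker_mpi_num_of_processes_per_host': 2,
        'sagemaker_mpi_custom_mpi_options': '-x NCCL_DEBUG=INFO'
    }

Note also that when MPI is enabled, the default ``model_dir`` becomes the local path ``/opt/ml/model`` rather
than a generated S3 location (see ``_default_s3_path`` in the ``estimator.py`` diff below), so the trained model
is uploaded as ``model.tar.gz`` under the job's output path after training.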

sagemaker.tensorflow.TensorFlow class
'''''''''''''''''''''''''''''''''''''

@@ -277,11 +335,10 @@ Optional:
- ``model_dir (str)`` Location where model data, checkpoint data, and TensorBoard checkpoints should be saved during training.
If not specified, an S3 location will be generated under the training job's default bucket, and ``model_dir`` will be
passed in your training script as one of the command line arguments.
-- ``distributions (dict)`` Configure your distrubtion strategy with this argument. For launching parameter server for
-  for distributed training, you must set ``distributions`` to ``{'parameter_server': {'enabled': True}}``
+- ``distributions (dict)`` Configure your distribution strategy with this argument.

Training with Pipe Mode using PipeModeDataset
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Amazon SageMaker allows users to create training jobs using Pipe input mode.
With Pipe input mode, your dataset is streamed directly to your training instances instead of being downloaded first.
@@ -327,9 +384,9 @@ To run training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your
    from sagemaker.tensorflow import TensorFlow

    tf_estimator = TensorFlow(entry_point='tf-train-with-pipemodedataset.py', role='SageMakerRole',
-                      training_steps=10000, evaluation_steps=100,
-                      train_instance_count=1, train_instance_type='ml.p2.xlarge',
-                      framework_version='1.10.0', input_mode='Pipe')
+                              training_steps=10000, evaluation_steps=100,
+                              train_instance_count=1, train_instance_type='ml.p2.xlarge',
+                              framework_version='1.10.0', input_mode='Pipe')

    tf_estimator.fit('s3://bucket/path/to/training/data')

2 changes: 1 addition & 1 deletion src/sagemaker/tensorflow/defaults.py

@@ -12,4 +12,4 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

-TF_VERSION = '1.11'
+TF_VERSION = '1.12'
34 changes: 27 additions & 7 deletions src/sagemaker/tensorflow/estimator.py

@@ -200,14 +200,21 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
            script_mode (bool): If set to True, the estimator will use the Script Mode containers (default: False).
                This will be ignored if py_version is set to 'py3'.
            distributions (dict): A dictionary with information on how to run distributed training
-                (default: None). Currently we only support distributed training with parameter servers. To enable it
-                use the following setup:
+                (default: None). Currently we support distributed training with parameter servers and MPI. To enable
+                parameter server use the following setup:
Contributor: s/server/servers

Contributor: for "To enable parameter server" - s/server/servers

                    {
                        'parameter_server':
                        {
                            'enabled': True
                        }
                    }
                To enable MPI:
                    {
                        'mpi':
                        {
                            'enabled': True
                        }
                    }
            **kwargs: Additional kwargs passed to the Framework constructor.
        """
        if framework_version is None:
@@ -419,13 +426,24 @@ def hyperparameters(self):
        hyperparameters = super(TensorFlow, self).hyperparameters()

        self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints')
+        mpi_enabled = False

        if self._script_mode_enabled():
-            self.model_dir = self.model_dir or self._default_s3_path('model')
-            additional_hyperparameters = {'model_dir': self.model_dir}
+            additional_hyperparameters = {}

            if 'parameter_server' in self.distributions:
-                enabled = self.distributions['parameter_server'].get('enabled', False)
-                additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
+                ps_enabled = self.distributions['parameter_server'].get('enabled', False)
+                additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled

+            if 'mpi' in self.distributions:
+                mpi_dict = self.distributions['mpi']
+                mpi_enabled = mpi_dict.get('enabled', False)
+                additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
+                additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get('processes_per_host', 1)
+                additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get('custom_mpi_options', '')

+            self.model_dir = self.model_dir or self._default_s3_path('model', mpi=mpi_enabled)
+            additional_hyperparameters['model_dir'] = self.model_dir
        else:
            additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
                                          'training_steps': self.training_steps,

@@ -435,10 +453,12 @@ def hyperparameters(self):

        hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
        return hyperparameters

-    def _default_s3_path(self, directory):
+    def _default_s3_path(self, directory, mpi=False):
        local_code = get_config_value('local.local_code', self.sagemaker_session.config)
        if self.sagemaker_session.local_mode and local_code:
            return '/opt/ml/shared/{}'.format(directory)
+        elif mpi:
+            return '/opt/ml/model'
Contributor: should we make this a constant?
        else:
            return os.path.join(self.output_path, self._current_job_name, directory)

135 changes: 135 additions & 0 deletions tests/data/tensorflow_mnist/horovod_mnist.py
@@ -0,0 +1,135 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import argparse
import os

import tensorflow as tf
import horovod.tensorflow as hvd

layers = tf.contrib.layers
learn = tf.contrib.learn

tf.logging.set_verbosity(tf.logging.INFO)


def _parse_args():
    parser = argparse.ArgumentParser()
    # Data, model, and output directories
    parser.add_argument('--output-data-dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--model_dir', type=str)
Contributor: nit: it's strange to me that we would mix underscores and hyphens in our examples like this

    return parser.parse_known_args()


def conv_model(feature, target, mode):
    """2-layer convolution model."""
    # Convert the target to a one-hot tensor of shape (batch_size, 10) with
    # an on-value of 1 for each one-hot vector of length 10.
    target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)

    # Reshape feature to a 4d tensor, with the 2nd and 3rd dimensions being
    # image width and height and the final dimension being the number of color channels.
    feature = tf.reshape(feature, [-1, 28, 28, 1])

    # First conv layer will compute 32 features for each 5x5 patch
    with tf.variable_scope('conv_layer1'):
        h_conv1 = layers.conv2d(
            feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
        h_pool1 = tf.nn.max_pool(
            h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    # Second conv layer will compute 64 features for each 5x5 patch.
    with tf.variable_scope('conv_layer2'):
        h_conv2 = layers.conv2d(
            h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
        h_pool2 = tf.nn.max_pool(
            h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        # reshape tensor into a batch of vectors
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])

    # Densely connected layer with 1024 neurons.
    h_fc1 = layers.dropout(
        layers.fully_connected(
            h_pool2_flat, 1024, activation_fn=tf.nn.relu),
        keep_prob=0.5,
        is_training=mode == tf.contrib.learn.ModeKeys.TRAIN)

    # Compute logits (1 per class) and compute loss.
    logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
    loss = tf.losses.softmax_cross_entropy(target, logits)

    return tf.argmax(logits, 1), loss


def main(_):
    args, unknown = _parse_args()

    # Horovod: initialize Horovod.
    hvd.init()

    # Download and load MNIST dataset.
    mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())
Contributor: What is the size of the dataset?

Contributor (author): The training data is 164M. With the eval data and the labels it's about 200M.

    # Build model...
    with tf.name_scope('input'):
        image = tf.placeholder(tf.float32, [None, 784], name='image')
        label = tf.placeholder(tf.float32, [None], name='label')
    predict, loss = conv_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)

    # Horovod: adjust learning rate based on number of GPUs.
    opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())

    # Horovod: add Horovod Distributed Optimizer.
    opt = hvd.DistributedOptimizer(opt)

    global_step = tf.contrib.framework.get_or_create_global_step()
    train_op = opt.minimize(loss, global_step=global_step)

    hooks = [
        # Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states
        # from rank 0 to all other processes. This is necessary to ensure consistent
        # initialization of all workers when training is started with random weights
        # or restored from a checkpoint.
        hvd.BroadcastGlobalVariablesHook(0),

        tf.train.StopAtStepHook(last_step=200 // hvd.size()),

        tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
                                   every_n_iter=10),
    ]

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Horovod: save checkpoints only on worker 0 to prevent other workers from
    # corrupting them.
    checkpoint_dir = os.path.join(args.model_dir, 'checkpoints') if hvd.rank() == 0 else None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=hooks,
                                           config=config) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step synchronously.
            image_, label_ = mnist.train.next_batch(100)
            mon_sess.run(train_op, feed_dict={image: image_, label: label_})


if __name__ == "__main__":
    tf.app.run()
45 changes: 43 additions & 2 deletions tests/integ/test_tf_script_mode.py

@@ -14,6 +14,8 @@

import os
import pytest
import tarfile
import tempfile

import boto3
from sagemaker.tensorflow import TensorFlow
@@ -23,7 +25,8 @@

RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'tensorflow_mnist')
SCRIPT = os.path.join(RESOURCE_PATH, 'mnist.py')
-DISTRIBUTION_ENABLED = {'parameter_server': {'enabled': True}}
+parameter_server_distribution = {'parameter_server': {'enabled': True}}
+mpi_distribution = {'mpi': {'enabled': True}}


@pytest.fixture(scope='session', params=['ml.c5.xlarge', 'ml.p2.xlarge'])
@@ -62,7 +65,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
                           py_version=integ.PYTHON_VERSION,
                           script_mode=True,
                           framework_version='1.11',
-                          distributions=DISTRIBUTION_ENABLED,
+                          distributions=parameter_server_distribution,
                           base_job_name='test-tf-sm-mnist')
    inputs = estimator.sagemaker_session.upload_data(
        path=os.path.join(RESOURCE_PATH, 'data'),
@@ -74,6 +77,29 @@ def test_mnist_distributed(sagemaker_session, instance_type):
                           ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta', 'saved_model.pb'])


def test_mnist_horovod_distributed(sagemaker_session):
    instance_type = 'ml.p3.2xlarge'
    estimator = TensorFlow(entry_point=os.path.join(RESOURCE_PATH, 'horovod_mnist.py'),
                           role='SageMakerRole',
                           train_instance_count=2,
                           train_instance_type=instance_type,
                           sagemaker_session=sagemaker_session,
                           py_version=integ.PYTHON_VERSION,
                           script_mode=True,
                           framework_version='1.12',
                           distributions=mpi_distribution,
                           base_job_name='test-tf-sm-horovod-mnist')
    inputs = estimator.sagemaker_session.upload_data(
        path=os.path.join(RESOURCE_PATH, 'data'),
        key_prefix='scriptmode/distributed_mnist')

    with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator.fit(inputs)
    model_dir = os.path.join(estimator.output_path, estimator._current_job_name, 'output', 'model.tar.gz')
    _assert_s3_files_exist_in_tar(model_dir,
                                  ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta', 'saved_model.pb'])


def _assert_s3_files_exist(s3_url, files):
    parsed_url = urlparse(s3_url)
    s3 = boto3.client('s3')

@@ -82,3 +108,18 @@ def _assert_s3_files_exist(s3_url, files):

        found = [x['Key'] for x in contents if x['Key'].endswith(f)]
        if not found:
            raise ValueError('File {} is not found under {}'.format(f, s3_url))


def _assert_s3_files_exist_in_tar(s3_url, files):
    parsed_url = urlparse(s3_url)
    tmp_file = tempfile.NamedTemporaryFile()
    s3 = boto3.resource('s3')
    object = s3.Bucket(parsed_url.netloc).Object(parsed_url.path.lstrip('/'))

    with open(tmp_file.name, 'wb') as temp_file:
        object.download_fileobj(temp_file)
    with tarfile.open(tmp_file.name, 'r') as tar_file:
        for f in files:
            found = [x for x in tar_file.getnames() if x.endswith(f)]
            if not found:
                raise ValueError('File {} is not found in {}'.format(f, s3_url))