doc: clarify how to use parameter servers with distributed MXNet training #1104


Merged (2 commits) on Oct 25, 2019
18 changes: 11 additions & 7 deletions doc/using_mxnet.rst
@@ -202,7 +202,7 @@ It is good practice to save the best model after each training epoch,
so that you can resume a training job if it gets interrupted.
This is particularly important if you are using Managed Spot training.

To save MXNet model checkpoints, do the following in your training script:

* Set the ``CHECKPOINTS_DIR`` environment variable and enable checkpoints.

@@ -213,7 +213,7 @@ To save MXNet model checkpoints, do the following in your training script:

* Make sure you are emitting a validation metric to test the model. For information, see `Evaluation Metric API <https://mxnet.incubator.apache.org/api/python/metric/metric.html>`_.
* After each training epoch, test whether the current model performs the best with respect to the validation metric, and if it does, save that model to ``CHECKPOINTS_DIR``.

.. code:: python

if checkpoints_enabled and current_host == hosts[0]:
@@ -224,7 +224,7 @@ To save MXNet model checkpoints, do the following in your training script:
trainer.save_states(CHECKPOINTS_DIR + '/%.4f-cifar10-%d.states'%(best_accuracy, epoch))

For a complete example of an MXNet training script that implements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.
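
As a rough sketch of the pattern above (not part of this diff), the per-epoch check might be factored out as follows; the ``net``, ``trainer``, ``current_host``, and ``hosts`` objects are assumed to come from your own Gluon training loop and SageMaker environment, and the file-name format simply mirrors the snippet above:

.. code:: python

    def save_best_checkpoint(net, trainer, test_accuracy, best_accuracy, epoch,
                             checkpoints_dir, checkpoints_enabled, current_host, hosts):
        """Save model parameters and trainer state only when the validation metric improves."""
        # Only the first host writes checkpoints, mirroring the snippet above.
        if checkpoints_enabled and current_host == hosts[0] and test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            net.save_parameters('%s/%.4f-cifar10-%d.params' % (checkpoints_dir, best_accuracy, epoch))
            trainer.save_states('%s/%.4f-cifar10-%d.states' % (checkpoints_dir, best_accuracy, epoch))
        return best_accuracy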


Updating your MXNet training script
-----------------------------------
@@ -331,7 +331,7 @@ The following code sample shows how you train a custom MXNet script "train.py".
'learning-rate': 0.1})
mxnet_estimator.fit('s3://my_bucket/my_training_data/')

For more information about the sagemaker.mxnet.MXNet estimator, see `sagemaker.mxnet.MXNet Class`_.
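
Because the hunk above is collapsed, the full constructor call is not shown here. A hedged sketch of what it might look like with the SDK's 1.x arguments follows; the role name, instance settings, framework version, and the extra hyperparameters are illustrative assumptions, not the exact lines from the file:

.. code:: python

    from sagemaker.mxnet import MXNet

    # Illustrative values only; substitute your own role, instances, and hyperparameters.
    mxnet_estimator = MXNet('train.py',
                            role='SageMakerRole',                 # assumed IAM role name
                            train_instance_type='ml.p3.2xlarge',  # assumed instance type
                            train_instance_count=1,
                            framework_version='1.4.1',            # assumed framework version
                            py_version='py3',
                            hyperparameters={'batch-size': 100,
                                             'epochs': 10,
                                             'learning-rate': 0.1})
    mxnet_estimator.fit('s3://my_bucket/my_training_data/')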



@@ -370,15 +370,19 @@ fit Optional arguments
Distributed training
====================

When writing a distributed training script, use an MXNet kvstore to store and share model parameters.
If you want to use parameter servers for distributed training, set the following parameter in your ``MXNet`` constructor:

.. code:: python

distributions={'parameter_server': {'enabled': True}}

Then, when writing a distributed training script, use an MXNet kvstore to store and share model parameters.
During training, SageMaker automatically starts an MXNet kvstore server and scheduler processes on hosts in your training job cluster.
Your script runs as an MXNet worker task, with one server process on each host in your cluster.
One host is selected arbitrarily to run the scheduler process.

To learn more about writing distributed MXNet programs, please see `Distributed Training <https://mxnet.incubator.apache.org/versions/master/faq/distributed_training.html>`__ in the MXNet docs.
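
For illustration only (this sketch is not part of the diff), the training-script side of that pattern might create a distributed kvstore and hand it to Gluon; the placeholder model and optimizer settings are assumptions, and the ``dist_sync`` kvstore only works when the script runs inside a distributed training job:

.. code:: python

    import mxnet as mx
    from mxnet import gluon

    # A distributed kvstore lets the worker processes that SageMaker starts share
    # parameters through the launched parameter servers. 'dist_sync' is one common
    # mode; 'dist_async' is another.
    kv = mx.kvstore.create('dist_sync')

    net = gluon.nn.Dense(10)   # placeholder model; use your own network here
    net.initialize()

    # Pass the kvstore to the Gluon Trainer (the optimizer settings are assumptions).
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.1}, kvstore=kv)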



*******************
Deploy MXNet models
*******************
6 changes: 3 additions & 3 deletions src/sagemaker/mxnet/estimator.py
@@ -93,9 +93,9 @@ def __init__(
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.

distributions (dict): A dictionary with information on how to run distributed
training (default: None).
distributions:
distributions (dict): A dictionary with information on how to run distributed
training (default: None). To have parameter servers launched for training,
set this value to be ``{'parameter_server': {'enabled': True}}``.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.
"""