
Commit 5c5ca0c

doc: clarify how to use parameter servers with distributed MXNet training (#1104)
1 parent a27a766

2 files changed (+14, -10 lines)

doc/using_mxnet.rst

Lines changed: 11 additions & 7 deletions
@@ -202,7 +202,7 @@ It is good practice to save the best model after each training epoch,
 so that you can resume a training job if it gets interrupted.
 This is particularly important if you are using Managed Spot training.
 
-To save MXNet model checkpoints, do the following in your training script:
+To save MXNet model checkpoints, do the following in your training script:
 
 * Set the ``CHECKPOINTS_DIR`` environment variable and enable checkpoints.
 
@@ -213,7 +213,7 @@ To save MXNet model checkpoints, do the following in your training script:
 
 * Make sure you are emitting a validation metric to test the model. For information, see `Evaluation Metric API <https://mxnet.incubator.apache.org/api/python/metric/metric.html>`_.
 * After each training epoch, test whether the current model performs the best with respect to the validation metric, and if it does, save that model to ``CHECKPOINTS_DIR``.
-
+
 .. code:: python
 
     if checkpoints_enabled and current_host == hosts[0]:
@@ -224,7 +224,7 @@ To save MXNet model checkpoints, do the following in your training script:
             trainer.save_states(CHECKPOINTS_DIR + '/%.4f-cifar10-%d.states'%(best_accuracy, epoch))
 
 For a complete example of an MXNet training script that implements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.
-
+
 
 Updating your MXNet training script
 -----------------------------------
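
The checkpointing snippet in this diff shows only the tail of the loop. A minimal sketch of how the full loop might fit together follows; the toy network, the fixed ``current_accuracy`` value, and the ``/opt/ml/checkpoints`` fallback path are illustrative assumptions, not part of the documented example:

.. code:: python

    import os

    import mxnet as mx
    from mxnet import gluon

    # The CHECKPOINTS_DIR environment variable is set by the training
    # configuration; the fallback path here is an assumption.
    CHECKPOINTS_DIR = os.environ.get('CHECKPOINTS_DIR', '/opt/ml/checkpoints')
    checkpoints_enabled = os.path.isdir(CHECKPOINTS_DIR)

    # A toy model and trainer stand in for the real ones.
    net = gluon.nn.Dense(10, in_units=20)
    net.initialize(mx.init.Xavier())
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

    best_accuracy = 0.0
    for epoch in range(10):
        # ... train for one epoch, then compute the validation metric;
        # a fixed value stands in for a real evaluation here.
        current_accuracy = 0.5

        # Save a checkpoint only when the model improves on the metric.
        if checkpoints_enabled and current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            net.save_parameters('%s/%.4f-cifar10-%d.params' % (CHECKPOINTS_DIR, best_accuracy, epoch))
            trainer.save_states('%s/%.4f-cifar10-%d.states' % (CHECKPOINTS_DIR, best_accuracy, epoch))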
@@ -331,7 +331,7 @@ The following code sample shows how you train a custom MXNet script "train.py".
                                            'learning-rate': 0.1})
     mxnet_estimator.fit('s3://my_bucket/my_training_data/')
 
-For more information about the sagemaker.mxnet.MXNet estimator, see `sagemaker.mxnet.MXNet Class`_.
+For more information about the sagemaker.mxnet.MXNet estimator, see `sagemaker.mxnet.MXNet Class`_.
 
 
 
@@ -370,15 +370,19 @@ fit Optional arguments
 Distributed training
 ====================
 
-When writing a distributed training script, use an MXNet kvstore to store and share model parameters.
+If you want to use parameter servers for distributed training, set the following parameter in your ``MXNet`` constructor:
+
+.. code:: python
+
+    distributions={'parameter_server': {'enabled': True}}
+
+Then, when writing a distributed training script, use an MXNet kvstore to store and share model parameters.
 During training, SageMaker automatically starts an MXNet kvstore server and scheduler processes on hosts in your training job cluster.
 Your script runs as an MXNet worker task, with one server process on each host in your cluster.
 One host is selected arbitrarily to run the scheduler process.
 
 To learn more about writing distributed MXNet programs, please see `Distributed Training <https://mxnet.incubator.apache.org/versions/master/faq/distributed_training.html>`__ in the MXNet docs.
 
-
-
 *******************
 Deploy MXNet models
 *******************
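
To make the kvstore step concrete, a minimal sketch of the script side might look like the following; the toy network and optimizer settings are illustrative, and ``'dist_sync'`` is one common kvstore mode for synchronous updates through the parameter servers:

.. code:: python

    import mxnet as mx
    from mxnet import gluon

    # 'dist_sync' pushes gradients to the parameter servers and pulls back
    # updated weights after every batch. Creating it requires the cluster
    # environment that SageMaker sets up for the training job.
    kv = mx.kv.create('dist_sync')

    # A toy model; the real script would build its own network.
    net = gluon.nn.Dense(10, in_units=20)
    net.initialize(mx.init.Xavier())
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.1}, kvstore=kv)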

src/sagemaker/mxnet/estimator.py

Lines changed: 3 additions & 3 deletions
@@ -93,9 +93,9 @@ def __init__(
                 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                 custom-image:latest.
 
-            distributions (dict): A dictionary with information on how to run distributed
-                training (default: None).
-                distributions:
+            distributions (dict): A dictionary with information on how to run distributed
+                training (default: None). To have parameter servers launched for training,
+                set this value to be ``{'parameter_server': {'enabled': True}}``.
             **kwargs: Additional kwargs passed to the
                 :class:`~sagemaker.estimator.Framework` constructor.
         """
