Skip to content

Commit eb05645

Browse files
authored
Merge pull request #134 from aidan-plenert-macdonald/master
Make the MXNet Gluon MNIST Example scale
2 parents 53aec70 + 7330773 commit eb05645

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

sagemaker-python-sdk/mxnet_gluon_mnist/mnist.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
# ------------------------------------------------------------ #
1717

1818

19-
def train(channel_input_dirs, hyperparameters, **kwargs):
19+
def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
2020
# SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
2121
# the current container environment, but here we just use simple cpu context.
2222
ctx = mx.cpu()
@@ -41,8 +41,15 @@ def train(channel_input_dirs, hyperparameters, **kwargs):
4141
# Collect all parameters from net and its children, then initialize them.
4242
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
4343
# Trainer is for updating parameters with gradient.
44+
45+
if len(hosts) == 1:
46+
kvstore = 'device' if num_gpus > 0 else 'local'
47+
else:
48+
kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'
49+
4450
trainer = gluon.Trainer(net.collect_params(), 'sgd',
45-
{'learning_rate': learning_rate, 'momentum': momentum})
51+
{'learning_rate': learning_rate, 'momentum': momentum},
52+
kvstore=kvstore)
4653
metric = mx.metric.Accuracy()
4754
loss = gluon.loss.SoftmaxCrossEntropyLoss()
4855

sagemaker-python-sdk/mxnet_gluon_sentiment/sentiment.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
3131
if len(hosts) == 1:
3232
kvstore = 'device' if num_gpus > 0 else 'local'
3333
else:
34-
kvstore = 'dist_sync'
34+
kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'
3535

3636
ctx = mx.gpu() if num_gpus > 0 else mx.cpu()
3737

@@ -56,7 +56,8 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
5656
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
5757
# Trainer is for updating parameters with gradient.
5858
trainer = gluon.Trainer(net.collect_params(), 'adam',
59-
{'learning_rate': learning_rate})
59+
{'learning_rate': learning_rate},
60+
kvstore=kvstore)
6061
metric = mx.metric.Accuracy()
6162
loss = gluon.loss.SoftmaxCrossEntropyLoss()
6263

0 commit comments

Comments (0)