Skip to content

Commit 8890bc6

Browse files
author
Macdonald
committed
Adding code to make gluon MNIST scale
1 parent 53aec70 commit 8890bc6

File tree

1 file changed

+5
-3
lines changed
  • sagemaker-python-sdk/mxnet_gluon_mnist

1 file changed

+5
-3
lines changed

sagemaker-python-sdk/mxnet_gluon_mnist/mnist.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import mxnet as mx
5-
from mxnet import gluon, autograd
5+
from mxnet import gluon, autograd, kv
66
from mxnet.gluon import nn
77
import numpy as np
88
import json
@@ -16,7 +16,7 @@
1616
# ------------------------------------------------------------ #
1717

1818

19-
def train(channel_input_dirs, hyperparameters, **kwargs):
19+
def train(channel_input_dirs, hyperparameters, hosts, **kwargs):
2020
# SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
2121
# the current container environment, but here we just use simple cpu context.
2222
ctx = mx.cpu()
@@ -41,8 +41,10 @@ def train(channel_input_dirs, hyperparameters, **kwargs):
4141
# Collect all parameters from net and its children, then initialize them.
4242
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
4343
# Trainer is for updating parameters with gradient.
44+
store = kv.create('dist_sync' if hosts > 1 else 'local')
4445
trainer = gluon.Trainer(net.collect_params(), 'sgd',
45-
{'learning_rate': learning_rate, 'momentum': momentum})
46+
{'learning_rate': learning_rate, 'momentum': momentum},
47+
kvstore=store)
4648
metric = mx.metric.Accuracy()
4749
loss = gluon.loss.SoftmaxCrossEntropyLoss()
4850

0 commit comments

Comments
 (0)