import logging
import mxnet as mx
- from mxnet import gluon, autograd
+ from mxnet import gluon, autograd, kv
from mxnet.gluon import nn
import numpy as np
import json

# ------------------------------------------------------------ #


- def train(channel_input_dirs, hyperparameters, **kwargs):
+ def train(channel_input_dirs, hyperparameters, hosts, **kwargs):
    # SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
    # the current container environment, but here we just use simple cpu context.
    ctx = mx.cpu()
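As an aside on the comment above (this snippet is not part of the commit): since SageMaker also hands train() arguments such as num_cpus and num_gpus, a context could be chosen from num_gpus instead of hard-coding CPU. A hypothetical helper, for illustration only:

    # Hypothetical helper, not in the commit: pick a compute context from the
    # num_gpus argument that the comment above says SageMaker passes in.
    def pick_context(num_gpus):
        return mx.gpu(0) if num_gpus > 0 else mx.cpu()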
@@ -41,8 +41,10 @@ def train(channel_input_dirs, hyperparameters, **kwargs):
    # Collect all parameters from net and its children, then initialize them.
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    # Trainer is for updating parameters with gradient.
+     store = kv.create('dist_sync' if len(hosts) > 1 else 'local')  # hosts is a list of host names
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
-                            {'learning_rate': learning_rate, 'momentum': momentum})
+                            {'learning_rate': learning_rate, 'momentum': momentum},
+                            kvstore=store)
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
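Putting the two changes together: kv.create builds a 'dist_sync' store when the job spans several hosts and a plain 'local' store otherwise, and passing that store to gluon.Trainer routes parameter updates through it. A minimal self-contained sketch of the pattern, assuming hosts is the list of host names SageMaker supplies (the network and hyperparameter values here are placeholders, not the commit's code):

    import mxnet as mx
    from mxnet import gluon, kv
    from mxnet.gluon import nn

    net = nn.Dense(10)   # stand-in network for illustration
    net.initialize(mx.init.Xavier(magnitude=2.24))

    hosts = ['algo-1']   # one entry => 'local'; several entries => 'dist_sync'
    store = kv.create('dist_sync' if len(hosts) > 1 else 'local')
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.1, 'momentum': 0.9},
                            kvstore=store)

With a 'dist_sync' store, parameter updates are synchronized across workers on each trainer.step(), which is why the store is created before the Trainer is constructed and passed in via the kvstore argument.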